[infer] set env for vllm ascend (#7745)

Author: hoshi-hiyouga
Date: 2025-04-17 01:08:55 +08:00
Committed by: GitHub
Parent: 2e518f255f
Commit: d222f63cb7

5 changed files with 28 additions and 21 deletions

@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any

 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -95,10 +95,6 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    if is_torch_npu_available():
-        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
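
For context, the block deleted from patch_config kept JIT compilation disabled on Ascend NPUs unless an environment flag opted back in; the commit title suggests this concern now lives in the environment setup for vLLM on Ascend rather than in model patching. Below is a minimal standalone sketch of the removed gating pattern; the is_env_enabled body is an assumption reimplemented here for self-containment (the real helper lives in ..extras.misc):

import os

import torch
from transformers import is_torch_npu_available


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # treat common truthy strings as "enabled" (assumed to match the
    # ..extras.misc helper that the old import pulled in)
    return os.getenv(env_var, default).lower() in ("true", "y", "1")


if is_torch_npu_available():
    # keep JIT compilation off on Ascend NPUs unless NPU_JIT_COMPILE is
    # explicitly set, avoiding slow on-the-fly operator compilation
    torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))

On machines without the torch_npu plugin the guard makes this a no-op, so the sketch is safe to run anywhere; exporting NPU_JIT_COMPILE=1 restores JIT compilation on Ascend devices.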