[infer] set env for vllm ascend (#7745)
This commit is contained in:
@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
|
||||
def _set_env_vars() -> None:
|
||||
if is_torch_npu_available():
|
||||
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
|
||||
torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
|
||||
# avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def _verify_model_args(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
if model_args.infer_backend != EngineName.HF:
|
||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
|
||||
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
|
||||
_set_transformers_logging()
|
||||
|
||||
# Check arguments
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
if model_args.infer_backend != EngineName.HF:
|
||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user