[misc] upgrade cli (#7714)
@@ -390,8 +390,10 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    # Setup logging
    _set_transformers_logging()

    # Check arguments
    if model_args.infer_backend == "vllm":
        if finetuning_args.stage != "sft":
            raise ValueError("vLLM engine only supports auto-regressive models.")
@@ -408,6 +410,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    # Post-process model arguments
    if model_args.export_dir is not None and model_args.export_device == "cpu":
        model_args.device_map = {"": torch.device("cpu")}
        if data_args.cutoff_len != DataArguments().cutoff_len:  # override cutoff_len if it is not default
@@ -421,8 +424,10 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

    # Setup logging
    _set_transformers_logging()

    # Check arguments
    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
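For context, here is a minimal usage sketch (not part of this diff) of how the patched parser can be driven programmatically with a dict of overrides instead of CLI arguments, matching the `args: Optional[Union[dict[str, Any], list[str]]]` signature shown above. The import path assumes the package's public export, and the concrete argument values (model name, template, cutoff length) are illustrative assumptions rather than values taken from this commit.

# Hedged sketch: exercising get_infer_args with a dict of overrides.
# The specific values below are hypothetical; they only demonstrate the
# checks visible in the hunks above (vLLM backend requires stage == "sft",
# and a non-default cutoff_len affects export post-processing).
from llamafactory.hparams import get_infer_args

model_args, data_args, finetuning_args, generating_args = get_infer_args(
    {
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical model
        "template": "llama3",      # assumed chat template name
        "stage": "sft",            # any other stage raises ValueError with the vLLM backend
        "infer_backend": "vllm",   # routes through the vLLM-specific checks above
        "cutoff_len": 2048,        # non-default value, relevant to the export branch
    }
)
print(model_args.infer_backend, data_args.cutoff_len)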