mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
[misc] fix parser (#9730)
This commit is contained in:
@@ -30,21 +30,6 @@ from .training_args import TrainingArguments
|
||||
InputArgument = dict[str, Any] | list[str] | None
|
||||
|
||||
|
||||
def validate_args(
|
||||
data_args: DataArguments,
|
||||
model_args: ModelArguments,
|
||||
training_args: TrainingArguments,
|
||||
sample_args: SampleArguments,
|
||||
):
|
||||
"""Validate arguments."""
|
||||
if (
|
||||
model_args.quant_config is not None
|
||||
and training_args.dist_config is not None
|
||||
and training_args.dist_config.name == "deepspeed"
|
||||
):
|
||||
raise ValueError("Quantization is not supported with deepspeed backend.")
|
||||
|
||||
|
||||
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
"""Parse arguments from command line or config file."""
|
||||
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
|
||||
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
validate_args(*parsed_args)
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user