Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
@@ -19,20 +19,39 @@ from llmtuner.hparams import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any, ...]:
    """Parse arguments from a dict, a config file, or the command line.

    The source is chosen in this order:

    1. ``args`` when it is not ``None`` -> ``parser.parse_dict(args)``.
    2. Exactly one command-line argument ending in ``.yaml`` or ``.yml``
       -> ``parser.parse_yaml_file`` on its absolute path.
    3. Exactly one command-line argument ending in ``.json``
       -> ``parser.parse_json_file`` on its absolute path.
    4. Anything else -> ``parser.parse_args_into_dataclasses()``.

    Args:
        parser: The parser whose ``parse_*`` methods are invoked.
        args: Optional mapping of argument values; takes priority over
            ``sys.argv`` when provided.

    Returns:
        Whatever the selected ``parser`` method returns, passed through
        unchanged.
    """
    # An explicit dict always wins over anything on the command line.
    if args is not None:
        return parser.parse_dict(args)

    # A single positional argument is treated as a config file path when its
    # extension is recognised; otherwise fall through to regular CLI parsing.
    if len(sys.argv) == 2:
        config_path = sys.argv[1]
        if config_path.endswith((".yaml", ".yml")):
            return parser.parse_yaml_file(os.path.abspath(config_path))
        if config_path.endswith(".json"):
            return parser.parse_json_file(os.path.abspath(config_path))

    return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def parse_train_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
    """Parse the training argument groups.

    Builds an ``HfArgumentParser`` over the five training dataclasses and
    hands it, together with ``args``, to ``_parse_args``.

    Args:
        args: Optional mapping of argument values, forwarded unchanged to
            ``_parse_args``.

    Returns:
        The result of ``_parse_args`` for the training parser.
    """
    # The order here matches the order of the annotated return tuple.
    train_dataclasses = (
        GeneralArguments,
        ModelArguments,
        DataArguments,
        Seq2SeqTrainingArguments,
        FinetuningArguments,
    )
    train_parser = HfArgumentParser(train_dataclasses)
    return _parse_args(train_parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
    """Parse the inference argument groups.

    Builds an ``HfArgumentParser`` over the four inference dataclasses and
    hands it, together with ``args``, to ``_parse_args``.

    Args:
        args: Optional mapping of argument values, forwarded unchanged to
            ``_parse_args``.

    Returns:
        The result of ``_parse_args`` for the inference parser.
    """
    # The order here matches the order of the annotated return tuple.
    infer_dataclasses = (
        ModelArguments,
        DataArguments,
        FinetuningArguments,
        GeneratingArguments,
    )
    infer_parser = HfArgumentParser(infer_dataclasses)
    return _parse_args(infer_parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
|
||||
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -73,13 +92,22 @@ def get_train_args(
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
|
||||
data_args.streaming = False
|
||||
|
||||
if data_args.dev_ratio > 1e-6 and data_args.streaming:
|
||||
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
|
||||
data_args.dev_ratio = 0
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
@@ -106,17 +134,7 @@ def get_train_args(
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
@@ -128,7 +146,4 @@ def get_infer_args(
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
Reference in New Issue
Block a user