release v0.1.4
Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
@@ -32,9 +32,9 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
 
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
     parser = HfArgumentParser((
-        GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
-    ))
+        ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
+    ))
     return _parse_args(parser, args)
 
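For context on the hunk above: the return annotation and the class tuple change together because HfArgumentParser returns parsed dataclasses positionally, in the order the classes were passed to its constructor. A minimal sketch of that behavior, using hypothetical stand-in dataclasses rather than the project's real argument classes:

    # HfArgumentParser yields parsed instances in the same order as the
    # dataclasses passed to it; `ModelArgs`/`DataArgs` are hypothetical
    # stand-ins for the project's argument classes.
    from dataclasses import dataclass

    from transformers import HfArgumentParser

    @dataclass
    class ModelArgs:
        model_name_or_path: str = "gpt2"

    @dataclass
    class DataArgs:
        dataset: str = "alpaca"

    parser = HfArgumentParser((ModelArgs, DataArgs))
    model_args, data_args = parser.parse_args_into_dataclasses(args=[])
    print(model_args.model_name_or_path, data_args.dataset)  # -> gpt2 alpaca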
@@ -51,7 +51,7 @@ def parse_infer_args(
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
 ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
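The unpacking fix in this hunk matters because tuple unpacking is purely positional: with the old order, each variable would silently receive the wrong dataclass. A self-contained illustration with toy classes, not the real ones:

    # Tuple unpacking assigns by position only: a mismatch with the
    # producer's order swaps objects silently instead of raising.
    from dataclasses import dataclass

    @dataclass
    class General:
        seed: int = 42

    @dataclass
    class Model:
        name: str = "llama"

    def parse() -> tuple:
        return Model(), General()  # producer's fixed order

    general, model = parse()           # wrong order: no error is raised
    assert isinstance(general, Model)  # ...but `general` is now a Model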
@@ -79,6 +79,12 @@ def get_train_args(
     assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
         "Quantization is only compatible with the LoRA method."
 
+    assert not (training_args.max_steps == -1 and data_args.streaming), \
+        "Please specify `max_steps` in streaming mode."
+
+    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
+        "Streaming mode does not support evaluation currently."
+
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
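The two asserts added here (and deleted further down, see the next hunk) move the streaming-mode checks earlier in get_train_args. The `max_steps` requirement follows from how streamed datasets behave: they are iterable and unsized, so the Trainer cannot derive a step count from an epoch count. A rough sketch of the underlying limitation; the dataset name is illustrative and the call needs network access:

    # A streamed dataset is an IterableDataset with no __len__, so the
    # number of optimization steps cannot be inferred from epochs;
    # `max_steps` has to be given explicitly.
    from datasets import load_dataset

    stream = load_dataset("c4", "en", split="train", streaming=True)
    try:
        len(stream)
    except TypeError:
        print("unsized stream: set `max_steps` instead of relying on epochs")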
@@ -108,12 +114,6 @@ def get_train_args(
         logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
         data_args.dev_ratio = 0
 
-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
-
-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
-
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
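On the optimizer line kept as context above: in the transformers releases contemporary with this commit, TrainingArguments defaulted `optim` to "adamw_hf", whose legacy AdamW implementation emits a deprecation FutureWarning, so the code swaps in torch's AdamW, matching its own "suppress warning" comment. A hedged sketch of the same remap in isolation; "out" is an arbitrary output path:

    # Reproduces the remap on its own: replace the deprecated HF AdamW
    # with torch's AdamW to silence the FutureWarning. Comparing against
    # a plain string works because `optim` is a str-backed enum.
    from transformers import TrainingArguments

    targs = TrainingArguments(output_dir="out")  # arbitrary path
    if targs.optim == "adamw_hf":
        targs.optim = "adamw_torch"
    print(targs.optim)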