refactor export, fix #1190

Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
hiyouga
2023-10-15 16:01:48 +08:00
parent 68330eab2a
commit c2e84d4558
9 changed files with 52 additions and 49 deletions


@@ -5,7 +5,6 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
-from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
@@ -13,8 +12,7 @@ from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 )
@@ -39,16 +37,14 @@ def parse_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
     parser = HfArgumentParser((
         ModelArguments,
         DataArguments,
         Seq2SeqTrainingArguments,
         FinetuningArguments,
-        GeneratingArguments,
-        GeneralArguments
+        GeneratingArguments
     ))
     return _parse_args(parser, args)
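
For orientation (not part of the diff): with GeneralArguments dropped, parse_train_args now yields five dataclasses, and the training stage is read from FinetuningArguments. A minimal usage sketch under that assumption, driving the parser with an explicit dictionary; the field names other than stage, and all values, are illustrative rather than taken from this commit:

    # Illustrative only: five dataclasses come back, and "stage" is assumed to be
    # a FinetuningArguments field after this refactor.
    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args({
        "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # ModelArguments (example value)
        "stage": "sft",                                     # FinetuningArguments (moved from GeneralArguments)
        "dataset": "alpaca_gpt4_en",                        # DataArguments (example value)
        "output_dir": "saves/sft-test",                     # Seq2SeqTrainingArguments (example path)
        "do_train": True,
    })
    assert finetuning_args.stage == "sft"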
@@ -77,10 +73,9 @@ def get_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -96,36 +91,36 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    if general_args.stage != "pt" and data_args.template is None:
+    if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if general_args.stage != "sft" and training_args.predict_with_generate:
+    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
         raise ValueError("RM and PPO stages can only be performed with the LoRA method.")

-    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
         raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")

-    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+    if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")

-    if general_args.stage in ["rm", "dpo"]:
+    if finetuning_args.stage in ["rm", "dpo"]:
         for dataset_attr in data_args.dataset_list:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

-    if general_args.stage == "ppo" and model_args.reward_model is None:
+    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")

-    if general_args.stage == "ppo" and data_args.streaming:
+    if finetuning_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not support PPO training currently.")

-    if general_args.stage == "ppo" and model_args.shift_attn:
+    if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

     if training_args.max_steps == -1 and data_args.streaming:
@@ -205,7 +200,7 @@ def get_train_args(
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

-    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args
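
For orientation (also not part of the diff): downstream entry points now unpack five values from get_train_args and branch on finetuning_args.stage instead of the removed general_args.stage. A minimal sketch of that calling convention; the import path and the run_exp-style dispatcher below are assumptions for illustration:

    from typing import Any, Dict, Optional

    from llmtuner.tuner.core.parser import get_train_args  # import path assumed for illustration

    def run_exp(args: Optional[Dict[str, Any]] = None) -> None:
        # Five return values after the refactor; general_args no longer exists.
        model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
        if finetuning_args.stage == "pt":
            pass  # pre-training workflow (illustrative dispatch)
        elif finetuning_args.stage == "sft":
            pass  # supervised fine-tuning workflow
        elif finetuning_args.stage in ("rm", "ppo", "dpo"):
            pass  # reward modeling / PPO / DPO workflows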
def get_infer_args(