refactor export, fix #1190
Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
This commit is contained in:
@@ -5,7 +5,6 @@ import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
@@ -13,8 +12,7 @@ from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
|
||||
@@ -39,16 +37,14 @@ def parse_train_args(
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
@@ -77,10 +73,9 @@ def get_train_args(
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -96,36 +91,36 @@ def get_train_args(
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
|
||||
if general_args.stage != "pt" and data_args.template is None:
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if general_args.stage != "sft" and training_args.predict_with_generate:
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
|
||||
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
raise ValueError("PPO and DPO stages can only be performed at training.")
|
||||
|
||||
if general_args.stage in ["rm", "dpo"]:
|
||||
if finetuning_args.stage in ["rm", "dpo"]:
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
if not dataset_attr.ranking:
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.reward_model is None:
|
||||
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if general_args.stage == "ppo" and data_args.streaming:
|
||||
if finetuning_args.stage == "ppo" and data_args.streaming:
|
||||
raise ValueError("Streaming mode does not suppport PPO training currently.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.shift_attn:
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
@@ -205,7 +200,7 @@ def get_train_args(
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
|
||||
Reference in New Issue
Block a user