Support DPO (Direct Preference Optimization) training (arXiv:2305.18290)

Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
hiyouga
2023-08-11 03:02:53 +08:00
parent 72dfd74005
commit ca719a8697
33 changed files with 513 additions and 192 deletions

View File

@@ -19,7 +19,7 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
parser = HfArgumentParser((
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -68,7 +95,7 @@ def get_train_args(
data_args.init_for_training()
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
@@ -76,6 +103,15 @@ def get_train_args(
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stage can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -133,12 +169,17 @@ def get_train_args(
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":