refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
hiyouga
2023-09-01 19:00:45 +08:00
parent 93be211f80
commit e5b72c6a77
19 changed files with 108 additions and 126 deletions
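Of the three items in the title, the diff excerpt below only shows the argument-parser side of the dataset_attr refactor; the "add eos in pt" part concerns the pretraining preprocessing, where each document gets an end-of-sequence token appended before packing. A minimal sketch of that idea (the function name and fields here are illustrative, not the exact llmtuner code):

# Illustrative sketch only -- not the exact llmtuner implementation.
from itertools import chain


def preprocess_pretrain(examples, tokenizer, block_size=1024):
    # Tokenize each document and append the EOS id so that document
    # boundaries survive the packing step below.
    token_ids = [
        tokenizer.encode(text, add_special_tokens=False) + [tokenizer.eos_token_id]
        for text in examples["text"]
    ]
    # Concatenate everything and cut into fixed-size blocks.
    concatenated = list(chain.from_iterable(token_ids))
    total_length = (len(concatenated) // block_size) * block_size
    input_ids = [concatenated[i: i + block_size] for i in range(0, total_length, block_size)]
    return {"input_ids": input_ids, "labels": [ids.copy() for ids in input_ids]}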


@@ -5,6 +5,7 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 from llmtuner.extras.logging import get_logger
@@ -110,6 +111,11 @@ def get_train_args(
     if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")
 
+    if general_args.stage in ["rm", "dpo"]:
+        for dataset_attr in data_args.dataset_list:
+            if not dataset_attr.ranking:
+                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
+
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
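The new ranking check assumes the refactored dataset attributes expose a boolean ranking flag for pairwise (chosen/rejected) datasets. A rough sketch of the kind of object it iterates over, limited to the two fields referenced in this diff; the real class carries more loading-related fields:

from dataclasses import dataclass


@dataclass
class DatasetAttr:
    # Only the two attributes referenced in the hunk above; illustrative.
    dataset_name: str
    ranking: bool = False  # True for pairwise (chosen/rejected) datasets


# Reward modeling / DPO would then require every dataset to be ranked:
dataset_list = [DatasetAttr("comparison_data", ranking=True)]  # hypothetical dataset name
for dataset_attr in dataset_list:
    if not dataset_attr.ranking:
        raise ValueError("Please use ranked datasets for reward modeling or DPO training.")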
@@ -166,6 +172,7 @@ def get_train_args(
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
+        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
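require_version, imported in the first hunk, is the transformers helper that raises an ImportError carrying the hint message when the installed package does not satisfy the requirement; here it gates checkpoint resumption on transformers>=4.31.0. For example:

from transformers.utils.versions import require_version

# No-op if the installed transformers is recent enough,
# otherwise raises ImportError with the hint text appended.
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")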
@@ -186,18 +193,6 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float16
 
-    # transfer training stage to dataset stage
-    dataset_stage = general_args.stage
-    if general_args.stage == "ppo":
-        dataset_stage = "sft"
-    elif general_args.stage == "dpo":
-        dataset_stage = "rm"
-
-    for dataset_attr in data_args.dataset_list:
-        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
-            raise ValueError("Dataset {} is not supported for the stage {}"
-                             .format(dataset_attr.dataset_name, general_args.stage))
-
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
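The retained model_max_length line is plain arithmetic over the data arguments: the model's maximum sequence length becomes the source budget plus the target budget. With example values (not necessarily the defaults):

# Example values only; in practice these come from --max_source_length / --max_target_length.
max_source_length = 1024
max_target_length = 512

model_max_length = max_source_length + max_target_length  # 1536 tokens total
assert model_max_length == 1536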