add pre-training script

commit 33fee45217 (parent 304be6dc28)
Author: hiyouga
Date: 2023-05-29 21:37:22 +08:00
Former-commit-id: 935d58de2b3a2eadc4f0bed28c3ad7dee32e9fd5

6 changed files with 159 additions and 21 deletions


@@ -24,7 +24,7 @@ from utils import (

def main():

-    # prepare pretrained model and dataset
+    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")

src/train_pt.py (new file, 80 lines)

@@ -0,0 +1,80 @@
# coding=utf-8
# Implements several parameter-efficient pre-training methods for LLaMA.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py

import math

from utils import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    DataCollatorForLLaMA,
    PeftTrainer,
    LogCallback,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
    data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")

        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
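A side note on the evaluation block above: perplexity is reported as exp(eval_loss), and the OverflowError guard matters because math.exp overflows for large losses. A minimal standalone illustration (the loss values here are made up):

import math

eval_loss = 2.3
print(round(math.exp(eval_loss), 2))  # perplexity ~= 9.97

try:
    math.exp(1000.0)  # a loss this large overflows a float64 exponential
except OverflowError:
    print("perplexity = inf")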


@@ -18,14 +18,14 @@ from utils import (

def main():

-    # prepare pretrained model and dataset
+    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)

-    training_args.remove_unused_columns = False # Important for pairwise dataset
+    training_args.remove_unused_columns = False # important for pairwise dataset

    # Split the dataset
    if training_args.do_train:


@@ -7,7 +7,7 @@ from .common import (
from .data_collator import DataCollatorForLLaMA

-from .peft_trainer import LogCallback
+from .peft_trainer import PeftTrainer, LogCallback

from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA


@@ -130,7 +130,7 @@ def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: Optional[FinetuningArguments] = None,
    is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.
@@ -142,11 +142,14 @@ def load_pretrained(
            logger.warning("Checkpoint is not found at evaluation, load the original model.")
            finetuning_args = FinetuningArguments(finetuning_type="none")
        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+            finetuning_args = FinetuningArguments.load_from_json(
+                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
+            )
        else:
            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")

-    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
+    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
+        "RM and PPO training can only be performed with LoRA method."

    tokenizer = LlamaTokenizer.from_pretrained(
        model_args.model_name_or_path,
@@ -207,7 +210,7 @@ def load_pretrained(

def prepare_args(
-    stage: Literal["sft", "rm", "ppo"]
+    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
@@ -230,7 +233,7 @@ def prepare_args(
    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")

    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")
@@ -290,7 +293,7 @@ def prepare_data(
                cache_dir=model_args.cache_dir
            )
        elif dataset_attr.load_from == "file":
-            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
+            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
            extension = dataset_attr.file_name.split(".")[-1]

            if dataset_attr.file_sha1 is not None:
@@ -299,7 +302,7 @@ def prepare_data(
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")

            raw_datasets = load_dataset(
-                extension,
+                extension if extension in ["csv", "json"] else "text",
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
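With the change above, plaintext corpora are routed through the "text" builder of the datasets library, which yields a single "text" column (renamed to the prompt column further down for pre-training). A minimal sketch of that loading path; the file name corpus.txt is hypothetical:

from datasets import load_dataset

# The "text" builder produces one example per line with a single "text" column.
raw_datasets = load_dataset("text", data_files="corpus.txt")
print(raw_datasets["train"].column_names)  # ['text']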
@@ -313,6 +316,9 @@ def prepare_data(
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

+        if dataset.column_names[0] == "text": # for plaintext (in pre-training)
+            dataset = dataset.rename_column("text", getattr(dataset_attr, "prompt_column"))
+
        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
@@ -340,7 +346,7 @@ def preprocess_data(
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
@@ -363,7 +369,7 @@ def preprocess_data(
                yield prompt, answer

    def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<s>??`
+        # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
        text_ids = tokenizer(examples["prompt"])["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
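The hunk above shows only the beginning of preprocess_pretrain_dataset. A self-contained sketch of the grouping step it describes, following the run_clm.py recipe cited in src/train_pt.py; the helper name and block_size argument are assumptions for illustration:

from itertools import chain

def group_texts(text_ids, block_size=4):
    # Concatenate all tokenized texts, then split into fixed-size blocks;
    # the tail shorter than block_size is dropped, as run_clm.py does.
    concatenated_ids = list(chain(*text_ids))
    total_length = (len(concatenated_ids) // block_size) * block_size
    input_ids = [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]
    return {"input_ids": input_ids, "labels": [ids.copy() for ids in input_ids]}

# Toy token ids standing in for tokenizer output.
print(group_texts([[1, 2, 3], [4, 5, 6, 7], [8, 9]]))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}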
@@ -395,7 +401,7 @@ def preprocess_data(
            model_inputs["labels"].append(labels)
        return model_inputs

-    def preprocess_evaluation_dataset(examples):
+    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `X <s>` and labels with format `Y <s>`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
@@ -436,7 +442,7 @@ def preprocess_data(
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

-    def print_sft_dataset_example(example):
+    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
        print("label_ids:\n{}".format(example["labels"]))
@@ -450,19 +456,19 @@ def preprocess_data(
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))

-    def print_ppo_dataset_example(example):
+    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
-        preprocess_function = preprocess_evaluation_dataset \
+        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
-        preprocess_function = preprocess_evaluation_dataset
+        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
@@ -474,11 +480,13 @@ def preprocess_data(
            desc="Running tokenizer on dataset"
        )

-    if stage == "sft":
-        print_sft_dataset_example(dataset[0])
+    if stage == "pt":
+        print_unsupervised_dataset_example(dataset[0])
+    elif stage == "sft":
+        print_supervised_dataset_example(dataset[0])
    elif stage == "rm":
        print_pairwise_dataset_example(dataset[0])
    elif stage == "ppo":
-        print_ppo_dataset_example(dataset[0])
+        print_unsupervised_dataset_example(dataset[0])

    return dataset