add pre-training script
Former-commit-id: 935d58de2b3a2eadc4f0bed28c3ad7dee32e9fd5
@@ -24,7 +24,7 @@ from utils import (
 
 def main():
 
-    # prepare pretrained model and dataset
+    # Prepare pretrained model and dataset
     model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
src/train_pt.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Implements several parameter-efficient pre-training methods for LLaMA.
+# This code is inspired by
+# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
+
+
+import math
+from utils import (
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
+    DataCollatorForLLaMA,
+    PeftTrainer,
+    LogCallback,
+    plot_loss
+)
+
+
+def main():
+
+    # Prepare pretrained model and dataset
+    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
+    dataset = prepare_data(model_args, data_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
+    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
+    data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
+
+    # Split the dataset
+    if training_args.do_train:
+        if data_args.dev_ratio > 1e-6:
+            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
+            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+        else:
+            trainer_kwargs = {"train_dataset": dataset}
+    else: # do_eval or do_predict
+        trainer_kwargs = {"eval_dataset": dataset}
+
+    # Initialize our Trainer
+    trainer = PeftTrainer(
+        finetuning_args=finetuning_args,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=[LogCallback()],
+        **trainer_kwargs
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        trainer.save_model()
+        if trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
+        metrics["perplexity"] = perplexity
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
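Note on the evaluation branch above: the reported perplexity is simply the exponential of the mean token-level cross-entropy returned as eval_loss, with an overflow guard. A minimal standalone sketch of that relationship (the helper name is ours, not part of the repository):

import math

def perplexity_from_loss(eval_loss: float) -> float:
    # perplexity = exp(mean token-level cross-entropy), guarded against overflow
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")

print(perplexity_from_loss(2.0))   # ~7.39
print(perplexity_from_loss(0.69))  # ~2.0, i.e. roughly two-way uncertainty per token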
@@ -18,14 +18,14 @@ from utils import (
 
 def main():
 
-    # prepare pretrained model and dataset
+    # Prepare pretrained model and dataset
     model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
 
-    training_args.remove_unused_columns = False # Important for pairwise dataset
+    training_args.remove_unused_columns = False # important for pairwise dataset
 
     # Split the dataset
     if training_args.do_train:
@@ -7,7 +7,7 @@ from .common import (
 
 from .data_collator import DataCollatorForLLaMA
 
-from .peft_trainer import LogCallback
+from .peft_trainer import PeftTrainer, LogCallback
 
 from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
 from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
@@ -130,7 +130,7 @@ def load_pretrained(
     model_args: ModelArguments,
     finetuning_args: Optional[FinetuningArguments] = None,
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     r"""
     Loads pretrained model and tokenizer.
@@ -142,11 +142,14 @@ def load_pretrained(
             logger.warning("Checkpoint is not found at evaluation, load the original model.")
             finetuning_args = FinetuningArguments(finetuning_type="none")
         elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+            finetuning_args = FinetuningArguments.load_from_json(
+                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
+            )
         else:
             raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
 
-    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
+    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
+        "RM and PPO training can only be performed with LoRA method."
 
     tokenizer = LlamaTokenizer.from_pretrained(
         model_args.model_name_or_path,
@@ -207,7 +210,7 @@ def load_pretrained(
 
 
 def prepare_args(
-    stage: Literal["sft", "rm", "ppo"]
+    stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
 
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
@@ -230,7 +233,7 @@ def prepare_args(
 
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
 
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
@@ -290,7 +293,7 @@ def prepare_data(
                 cache_dir=model_args.cache_dir
             )
         elif dataset_attr.load_from == "file":
-            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
+            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
            extension = dataset_attr.file_name.split(".")[-1]
 
             if dataset_attr.file_sha1 is not None:
@@ -299,7 +302,7 @@ def prepare_data(
                 logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
 
             raw_datasets = load_dataset(
-                extension,
+                extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,
                 cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None
@@ -313,6 +316,9 @@ def prepare_data(
             max_samples_temp = min(len(dataset), max_samples)
             dataset = dataset.select(range(max_samples_temp))
 
+        if dataset.column_names[0] == "text": # for plaintext (in pre-training)
+            dataset = dataset.rename_column("text", getattr(dataset_attr, "prompt_column"))
+
         dummy_data = [None] * len(dataset)
         for column_name, target_name in [
             ("prompt_column", "prompt"),
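For context on the new branch above: when a plaintext corpus is loaded through the "text" builder (the fallback chosen in the load_dataset call earlier in this file), every line becomes a row in a single "text" column, which is then renamed to the dataset's prompt column so the pre-training preprocessor can find it. A rough sketch using the Hugging Face datasets API; the file name and target column are placeholders, not values from the repository's dataset_info.json:

from datasets import load_dataset

raw = load_dataset("text", data_files="corpus.txt")["train"]  # one row per line, column "text"
print(raw.column_names)   # ['text']
raw = raw.rename_column("text", "instruction")                # illustrative prompt column name
print(raw.column_names)   # ['instruction']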
@@ -340,7 +346,7 @@ def preprocess_data(
     tokenizer: PreTrainedTokenizer,
     data_args: DataTrainingArguments,
     training_args: Seq2SeqTrainingArguments,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:
 
     column_names = list(dataset.column_names)
@@ -363,7 +369,7 @@ def preprocess_data(
                 yield prompt, answer
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<s>??`
+        # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
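The hunk above shows only the first lines of preprocess_pretrain_dataset; the rest of the function (unchanged in this diff) follows the run_clm.py recipe cited in train_pt.py, concatenating all token ids and cutting them into fixed-length blocks that serve as both input_ids and labels for causal LM pre-training. A minimal sketch of that grouping step, with an assumed block size and helper name:

from itertools import chain

def group_texts(text_ids, block_size=512):
    # flatten the tokenized samples, drop the remainder, and split into fixed-length blocks
    concatenated_ids = list(chain(*text_ids))
    total_length = (len(concatenated_ids) // block_size) * block_size
    return [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]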
@@ -395,7 +401,7 @@ def preprocess_data(
             model_inputs["labels"].append(labels)
         return model_inputs
 
-    def preprocess_evaluation_dataset(examples):
+    def preprocess_unsupervised_dataset(examples):
         # build inputs with format `X <s>` and labels with format `Y <s>`
         model_inputs = {"input_ids": [], "labels": []}
         for prompt, answer in format_example(examples):
@@ -436,7 +442,7 @@ def preprocess_data(
             model_inputs["reject_ids"].append(reject_ids)
         return model_inputs
 
-    def print_sft_dataset_example(example):
+    def print_supervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
@@ -450,19 +456,19 @@ def preprocess_data(
         print("reject_ids:\n{}".format(example["reject_ids"]))
         print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
 
-    def print_ppo_dataset_example(example):
+    def print_unsupervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
 
     if stage == "pt":
         preprocess_function = preprocess_pretrain_dataset
     elif stage == "sft":
-        preprocess_function = preprocess_evaluation_dataset \
+        preprocess_function = preprocess_unsupervised_dataset \
             if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":
-        preprocess_function = preprocess_evaluation_dataset
+        preprocess_function = preprocess_unsupervised_dataset
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
         dataset = dataset.map(
@@ -474,11 +480,13 @@ def preprocess_data(
             desc="Running tokenizer on dataset"
         )
 
-    if stage == "sft":
-        print_sft_dataset_example(dataset[0])
+    if stage == "pt":
+        print_unsupervised_dataset_example(dataset[0])
+    elif stage == "sft":
+        print_supervised_dataset_example(dataset[0])
     elif stage == "rm":
         print_pairwise_dataset_example(dataset[0])
     elif stage == "ppo":
-        print_ppo_dataset_example(dataset[0])
+        print_unsupervised_dataset_example(dataset[0])
 
     return dataset