Initial commit

Author: hiyouga
Date: 2023-05-28 18:09:04 +08:00
Commit: 17024ebc1a
Former-commit-id: 5ca8e1d63727e7bcb8cab16542c763c47e48184a

29 changed files with 2399 additions and 0 deletions

src/__init__.py (new file, 0 lines)

src/cli_demo.py (new file, 66 lines)

@@ -0,0 +1,66 @@
# coding=utf-8
# Implements a command-line chat demo for LLaMA fine-tuned with PEFT.
# Usage: python cli_demo.py --model_name_or_path path_to_llama_model --checkpoint_dir path_to_checkpoint
import torch
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
def main():
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
def predict(query, history: list):
inputs = tokenizer([query], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,
"top_p": 0.9,
"top_k": 40,
"temperature": 0.7,
"num_beams": 1,
"max_new_tokens": 256,
"repetition_penalty": 1.5
}
with torch.no_grad():
generation_output = model.generate(**inputs, **gen_kwargs)
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
history = history + [(query, response)]
return response, history
history = []
print("欢迎使用 LLaMA-7B 模型输入内容即可对话clear清空对话历史stop终止程序")
while True:
try:
query = input("\nInput: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
continue
response, history = predict(query, history)
print("LLaMA-7B:", response)
if __name__ == "__main__":
main()
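For scripted (non-interactive) use, the same loading and generation path can be driven directly. A minimal sketch, assuming it is run from the src/ directory with a CUDA device available; the model and checkpoint paths are placeholders:

import torch
from utils import ModelArguments, load_pretrained

model_args = ModelArguments(
    model_name_or_path="path_to_llama_model",  # placeholder
    checkpoint_dir="path_to_checkpoint"        # placeholder, optional
)
model, tokenizer = load_pretrained(model_args)
model = model.cuda().eval()

inputs = tokenizer(["Hello, who are you?"], return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, do_sample=True, top_p=0.9, max_new_tokens=128)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))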

src/export_model.py (new file, 23 lines)

@@ -0,0 +1,23 @@
# coding=utf-8
# Exports the fine-tuned LLaMA model.
# Usage: python export_model.py --model_name_or_path path_to_llama_model --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from transformers import HfArgumentParser, TrainingArguments
from utils import ModelArguments, load_pretrained
def main():
parser = HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
if __name__ == "__main__":
main()
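Once exported, the directory behaves like a regular Hugging Face checkpoint and can be reloaded without PEFT. A minimal sketch, where the path is a placeholder for the --output_dir passed to export_model.py:

from transformers import LlamaForCausalLM, LlamaTokenizer

export_dir = "path_to_save_model"  # placeholder: the --output_dir used above
model = LlamaForCausalLM.from_pretrained(export_dir)
tokenizer = LlamaTokenizer.from_pretrained(export_dir)
print(model.num_parameters(), type(tokenizer).__name__)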

src/train_ppo.py (new file, 80 lines)

@@ -0,0 +1,80 @@
# coding=utf-8
# Implements parameter-efficient PPO training of fine-tuned LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math
from torch.optim import AdamW
from transformers.optimization import get_scheduler
from trl import PPOConfig
from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
DataCollatorForLLaMA,
PPOTrainerForLLaMA,
plot_loss
)
def main():
# prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
total_train_batch_size = \
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
)
# Initialize our Trainer
ppo_trainer = PPOTrainerForLLaMA(
training_args=training_args,
finetuning_args=finetuning_args,
config=ppo_config,
model=model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "reward"])
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
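To make the scheduler arithmetic above concrete, a small worked example with assumed values (batch size 4, gradient accumulation 4, one process, 1000 prompts, 1 epoch):

import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
world_size = 1
num_train_epochs = 1
len_dataset = 1000

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 16
num_training_steps = num_train_epochs * math.ceil(len_dataset / total_train_batch_size)  # 63
print(total_train_batch_size, num_training_steps)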

src/train_rm.py (new file, 72 lines)

@@ -0,0 +1,72 @@
# coding=utf-8
# Implements parameter-efficient training of a reward model based on LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
PairwiseDataCollatorForLLaMA,
PairwiseTrainerForLLaMA,
plot_loss
)
def main():
# prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
training_args.remove_unused_columns = False # Important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PairwiseTrainerForLLaMA(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
**trainer_kwargs
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
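In the rm stage, preprocess_pairwise_dataset (see src/utils/common.py below) expects the response field of each example to hold a pair of answers, the first preferred and the second rejected. A hypothetical record with made-up content, using the default column names from DatasetAttr:

example = {
    "instruction": "Summarize the following sentence.",       # mapped to the prompt column
    "input": "The quick brown fox jumps over the lazy dog.",  # mapped to the query column
    "output": [                                               # mapped to the response column
        "A fox jumps over a dog.",  # chosen answer  -> accept_ids
        "The dog is brown.",        # rejected answer -> reject_ids
    ],
}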

src/train_sft.py (new file, 95 lines)

@@ -0,0 +1,95 @@
# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods for LLaMA.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from utils import (
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
DataCollatorForLLaMA,
Seq2SeqTrainerForLLaMA,
ComputeMetrics,
get_logits_processor,
plot_loss
)
def main():
# Prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
training_args.generation_num_beams = data_args.num_beams if \
data_args.num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = Seq2SeqTrainerForLLaMA(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
)
# Keyword arguments for `model.generate`
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_length": data_args.max_source_length + data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results, tokenizer)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
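A hypothetical single-GPU LoRA fine-tuning command; every flag comes from the argument dataclasses introduced in this commit, while the paths and hyperparameter values are placeholders:

python src/train_sft.py \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_zh \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --fp16 \
    --plot_loss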

src/utils/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
from .common import (
load_pretrained,
prepare_args,
prepare_data,
preprocess_data
)
from .data_collator import DataCollatorForLLaMA
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
from .ppo import PPOTrainerForLLaMA
from .config import ModelArguments
from .other import auto_configure_device_map, get_logits_processor, plot_loss

src/utils/common.py (new file, 459 lines)

@@ -0,0 +1,459 @@
import os
import sys
import torch
import hashlib
from typing import List, Literal, Optional, Tuple
import transformers
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import datasets
from datasets import Dataset, concatenate_datasets, load_dataset
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from trl import AutoModelForCausalLMWithValueHead
from .config import (
ModelArguments,
DataTrainingArguments,
FinetuningArguments
)
from .other import (
get_logger,
load_trainable_params,
load_valuehead_params,
print_trainable_params,
prepare_model_for_training,
IGNORE_INDEX,
FINETUNING_ARGS_NAME
)
check_min_version("4.29.1")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
logger = get_logger(__name__)
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: bool
) -> PreTrainedModel:
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze":
logger.info("Fine-tuning method: Freeze")
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
lastest_checkpoint = None
if model_args.checkpoint_dir is not None:
if is_trainable and finetuning_args.resume_lora_training: # continually train on the lora weights
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if lastest_checkpoint is not None: # resume lora training
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True)
if is_trainable and lastest_checkpoint is None: # create new lora weights while training
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
)
model = get_peft_model(model, lora_config)
return model
def load_pretrained(
model_args: ModelArguments,
finetuning_args: Optional[FinetuningArguments] = None,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
if (not is_trainable) and (model_args.checkpoint_dir is None):
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
for checkpoint_dir in model_args.checkpoint_dir:
if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
tokenizer = LlamaTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side="left"
)
tokenizer.pad_token_id = 0 # set as the <unk> token
# Quantization configurations (using bitsandbytes library).
config_kwargs = {}
if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
config_kwargs["load_in_8bit"] = True
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pretrained models (without valuehead).
model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() # cast all params to float16 for inference
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
load_valuehead_params(model, model_args.reward_model)
# Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
# To meet the compliance requirements of the transformers library
if model_args.quantization_bit is not None:
model._is_int8_training_enabled = True
print_trainable_params(model)
return model, tokenizer
def prepare_args(
stage: Literal["sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
training_args.ddp_find_unused_parameters = False
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args
def prepare_data(
model_args: ModelArguments,
data_args: DataTrainingArguments
) -> Dataset:
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
elif dataset_attr.load_from == "script":
raw_datasets = load_dataset(
os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
cache_dir=model_args.cache_dir
)
elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
extension = dataset_attr.file_name.split(".")[-1]
if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset(
extension,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
else:
raise NotImplementedError
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # every dataset is normalized to the same four columns
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
def preprocess_data(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataTrainingArguments,
training_args: Seq2SeqTrainingArguments,
stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Dataset:
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
def format_example(examples): # support question with a single answer or multiple answers
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += examples["query"][i]
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n" + prefix
if examples["history"][i]:
history = examples["history"][i]
for old_query, response in history:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant: ".format(query)
yield prompt, answer
def preprocess_supervised_dataset(examples):
# build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_evaluation_dataset(examples):
# build inputs with format `X <s>` and labels with format `Y <s>`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # bos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id]
labels = target_ids + [tokenizer.bos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `X <s> Y1 </s>` and `X <s> Y2 </s>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs
def print_sft_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
def print_ppo_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
if stage == "sft":
if (not training_args.do_train) and training_args.predict_with_generate: # with generation
preprocess_function = preprocess_evaluation_dataset
else: # without generation
preprocess_function = preprocess_supervised_dataset
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
preprocess_function = preprocess_evaluation_dataset
with training_args.main_process_first(desc="dataset map pre-processing"):
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
if stage == "sft":
print_sft_dataset_example(dataset[0])
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_ppo_dataset_example(dataset[0])
return dataset
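To illustrate the X <s> Y </s> layout built in preprocess_supervised_dataset, a tiny sketch with made-up token ids (1 and 2 are the LLaMA bos/eos ids, -100 is IGNORE_INDEX):

IGNORE_INDEX = -100
bos_token_id, eos_token_id = 1, 2  # LLaMA special token ids
source_ids = [319, 29871, 6113]    # made-up prompt tokens (X)
target_ids = [15043, 3186]         # made-up answer tokens (Y)

input_ids = source_ids + [bos_token_id] + target_ids + [eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + [bos_token_id] + target_ids + [eos_token_id]
print(input_ids)  # [319, 29871, 6113, 1, 15043, 3186, 2]
print(labels)     # [-100, -100, -100, 1, 15043, 3186, 2]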

src/utils/config.py (new file, 212 lines)

@@ -0,0 +1,212 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
file_sha1: Optional[str] = None
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
)
use_fast_tokenizer: Optional[bool] = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
)
use_auth_token: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
def __post_init__(self):
if self.checkpoint_dir is not None: # support merging lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
dataset: Optional[str] = field(
default="alpaca_zh",
metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."}
)
max_source_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total input sequence length after tokenization."}
)
max_target_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total output sequence length after tokenization."}
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
)
num_beams: Optional[int] = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
)
ignore_pad_token_for_loss: Optional[bool] = field(
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
source_prefix: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
dev_ratio: Optional[float] = field(
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
def __post_init__(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
dataset_info = json.load(open(os.path.join(self.dataset_dir, "dataset_info.json"), "r"))
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
)
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
@dataclass
class FinetuningArguments:
"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
)
lora_dropout: Optional[float] = field(
default=0.1,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str):
self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
if self.name_module_trainable == "mlp":
self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
elif self.name_module_trainable == "qkv":
self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Create an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
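For reference, DataTrainingArguments.__post_init__ expects data/dataset_info.json entries shaped roughly as follows; the dataset names, URLs and file names here are hypothetical, and the structure is shown as a Python dict for brevity:

dataset_info = {
    "some_hub_dataset": {                    # loaded from the Hugging Face Hub
        "hf_hub_url": "some_org/some_dataset"
    },
    "some_local_dataset": {                  # loaded from a file under dataset_dir
        "file_name": "some_local_dataset.json",
        "file_sha1": "sha1_of_the_file",     # optional; checksum is skipped when absent
        "columns": {                         # optional mapping to the raw column names
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    }
}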

src/utils/data_collator.py (new file, 67 lines)

@@ -0,0 +1,67 @@
import torch
from typing import Dict, Optional, Sequence, Union
from transformers import DataCollatorWithPadding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from .other import IGNORE_INDEX
class DataCollatorForLLaMA(DataCollatorWithPadding):
r"""
Data collator for LLaMA. It dynamically pads batched data.
"""
def __init__(
self,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
ignore_pad_token_for_loss: Optional[bool] = False
):
super().__init__(tokenizer, padding=True)
self.model = model
self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
r"""
Generates attention masks for left-padded sequences.
"""
batch_size, seq_length = input_ids.size()
attention_mask = torch.ones((batch_size, seq_length), device=device)
for i, seq in enumerate(input_ids):
attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
attention_mask = attention_mask.bool()
return attention_mask
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We adopt left-padding in both training and evaluation.
"""
if isinstance(features[0]["input_ids"], torch.Tensor):
input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
else:
input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
if "labels" in features[0]:
if isinstance(features[0]["labels"], torch.Tensor):
labels = [feature["labels"].clone().detach().flip(0) for feature in features]
else:
labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
input_ids = input_ids + labels # pad them to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
batch = {}
if "labels" in features[0]:
input_ids, labels = input_ids.split(len(features), dim=0)
labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
batch["labels"] = labels
batch["input_ids"] = input_ids
batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
return batch
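The flip / pad / flip trick in __call__ is what turns pad_sequence (which pads on the right) into left padding. A self-contained illustration with toy sequences:

import torch

pad_token_id = 0
sequences = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

flipped = [seq.flip(0) for seq in sequences]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=pad_token_id)
left_padded = padded.flip(-1)
print(left_padded)
# tensor([[5, 6, 7],
#         [0, 8, 9]])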

src/utils/other.py (new file, 205 lines)

@@ -0,0 +1,205 @@
import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from peft.utils.other import WEIGHTS_NAME
IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[logging.StreamHandler(sys.stdout)]
)
def get_logger(name: str) -> logging.Logger:
return logging.getLogger(name)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
# Avoid runtime error in model.generate(do_sample=True).
# Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting
) -> PreTrainedModel:
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if hasattr(model, output_embedding_layer_name):
output_embedding_layer = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x):
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
return model
def print_trainable_params(model: torch.nn.Module) -> None:
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
r"""
Configures device map for LLaMA.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
"""
num_layers = 32 # LLaMA-7B has 32 decoder layers
layers_per_gpu = (num_layers + 2) / num_gpus # embed_tokens, norm and lm_head on GPU 0 count as two extra layers
device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
added_layers = 2
target_gpu = 0
for i in range(num_layers):
if added_layers >= layers_per_gpu:
target_gpu += 1
added_layers = 0
assert target_gpu < num_gpus
device_map[f"model.layers.{i}"] = target_gpu
added_layers += 1
return device_map
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
last = smoothed_val
return smoothed
def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
import matplotlib.pyplot as plt
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
for key in keys:
steps, metrics = [], []
for i in range(len(data["log_history"])):
if key in data["log_history"][i]:
steps.append(data["log_history"][i]["step"])
metrics.append(data["log_history"][i][key])
if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.")
continue
plt.figure()
plt.plot(steps, metrics, alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), label="smoothed")
plt.title("training {} of {}".format(key, training_args.output_dir))
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))

src/utils/pairwise.py (new file, 51 lines)

@@ -0,0 +1,51 @@
import torch
from typing import Dict, Sequence, Union
from .data_collator import DataCollatorForLLaMA
from .peft_trainer import PeftTrainer
from .other import get_logger
logger = get_logger(__name__)
class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
return super().__call__(features)
class PairwiseTrainerForLLaMA(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
def compute_loss(self, model, inputs, return_outputs=False):
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
We use the score on the EOS token to represent the reward of the whole sentence.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
outputs = {"r_accept": r_accept, "r_reject": r_reject}
return (loss, outputs) if return_outputs else loss
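The objective in compute_loss is the pairwise logistic (Bradley-Terry style) loss, -log sigmoid(r_accept - r_reject). A tiny numeric check with made-up rewards:

import torch

r_accept = torch.tensor([1.2, 0.3])
r_reject = torch.tensor([0.4, 0.9])
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss.item())  # ~0.704; smaller when chosen responses outscore rejected ones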

src/utils/peft_trainer.py (new file, 78 lines)

@@ -0,0 +1,78 @@
import os
import torch
from typing import Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from peft.utils.other import WEIGHTS_NAME
from .config import FinetuningArguments
from .other import (
get_logger,
get_state_dict,
load_trainable_params,
load_valuehead_params,
FINETUNING_ARGS_NAME,
VALUE_HEAD_FILE_NAME
)
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Saves trainable parameters as model checkpoint.
This function will only be executed at the process zero.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if hasattr(model, "pretrained_model"): # for models with valuehead
backbone_model = getattr(model, "pretrained_model")
else:
backbone_model = model
if hasattr(backbone_model, "peft_config"): # peft methods
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
else:
torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
r"""
Loads trainable parameters from model checkpoint.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
if hasattr(model, "peft_config"): # peft methods
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
else:
load_trainable_params(model, self.state.best_model_checkpoint)
if hasattr(model, "v_head"):
load_valuehead_params(model, self.state.best_model_checkpoint)

src/utils/ppo.py (new file, 241 lines)

@@ -0,0 +1,241 @@
import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TrainerState
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
from .peft_trainer import PeftTrainer
from .config import FinetuningArguments
from .other import (
AverageMeter,
get_logger,
get_logits_processor
)
logger = get_logger(__name__)
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
if target == "reward": # save original head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": getattr(model, "{}_head_weight".format(target)),
"summary.bias": getattr(model, "{}_head_bias".format(target))
})
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
def ppo_train(self, max_target_length: int) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
len_dataloader = len(self.dataloader)
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
reward_meter = AverageMeter()
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
for _ in range(self.config.gradient_accumulation_steps):
batch = next(dataiter)
steps_trained += 1
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get response from LLaMA
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
queries: List[torch.Tensor] = []
responses: List[torch.Tensor] = []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
if response_length < 2: # make response have at least 2 tokens
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
else:
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1]]
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"])
reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards))
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = {
"loss": round(loss_meter.avg, 4),
"reward": round(reward_meter.avg, 4),
"learning_rate": stats["ppo/learning_rate"],
"epoch": round(step / num_steps_per_epoch, 2)
}
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
loss_meter.reset()
reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
@torch.no_grad()
def generate(
self,
inputs: Dict[str, torch.Tensor],
length_sampler: Callable = None,
return_prompt: bool = True,
**generation_kwargs,
) -> torch.Tensor:
r"""
Generates model's responses given queries.
Subclass and override to inject custom behavior.
"""
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,
model: AutoModelForCausalLMWithValueHead,
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
):
r"""
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
bs = len(model_inputs["input_ids"])
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
for i in range(int(bs / fbs)):
input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences
logits, _, values = model(**input_kwargs)
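# logits at position t predict the token at position t + 1, hence the shifted slices below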
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
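# inputs are left-padded, so the first bos token marks where the real sequence (query + response) begins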
masks = torch.zeros_like(input_ids)
for j in range(fbs):
start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item()
masks[j][start:] = 1
if len(masks[j][start:]) < 2:
raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
all_logits.append(logits)
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
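# drop the last position so the returned logits, values and masks line up with the shifted log-probs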
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1],
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
self._save(output_dir)
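As a reading aid, here is a minimal standalone sketch (not part of the repository) of the padding handling used in the training loop above: queries come out of the data collator left-padded, while generated responses are right-padded, so the loop strips padding from opposite ends before re-concatenating them for the forward pass. The special-token ids below are placeholders for illustration, not the real LLaMA values.

import torch

PAD, BOS, EOS = 0, 1, 2  # placeholder special-token ids, illustration only

query = torch.tensor([PAD, PAD, BOS, 5, 6, 7])      # left-padded query from the collator
response = torch.tensor([8, 9, EOS, PAD, PAD])      # right-padded response from generate()

query_start = (query != PAD).nonzero()[0]           # index of the first real token
response_end = (response != PAD).nonzero()[-1] + 1  # one past the last real token

trimmed_query = query[query_start:]                 # tensor([1, 5, 6, 7])
trimmed_response = response[:response_end]          # tensor([8, 9, 2])

input_ids = torch.cat([trimmed_query, trimmed_response])
print(input_ids)                                    # tensor([1, 5, 6, 7, 8, 9, 2])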

96
src/utils/seq2seq.py Normal file
View File

@@ -0,0 +1,96 @@
import os
import json
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.tokenization_utils import PreTrainedTokenizer
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from .peft_trainer import PeftTrainer
from .other import get_logger, IGNORE_INDEX
logger = get_logger(__name__)
@dataclass
class ComputeMetrics:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
"""
tokenizer: PreTrainedTokenizer
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
# Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
for pred, label in zip(preds, labels):
pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
if len(" ".join(hypothesis).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
else:
rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
result = scores[0]
for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3)  # score the segmented text rather than raw token ids
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
return {k: float(np.mean(v)) for k, v in score_dict.items()}
class Seq2SeqTrainerForLLaMA(PeftTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def save_predictions(
self,
predict_results: PredictionOutput,
tokenizer: PreTrainedTokenizer
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for pred, label in zip(preds, labels):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))
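For reference, ComputeMetrics follows the ChatGLM evaluation recipe: decoded strings are segmented with jieba, scored with rouge_chinese, and BLEU-4 is computed with NLTK's smoothed sentence_bleu. A minimal standalone sketch with two made-up sentences, assuming jieba, rouge-chinese and nltk are installed:

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

prediction = "今天天气很好"   # made-up model output
reference = "今天的天气不错"  # made-up reference

hypothesis_tokens = list(jieba.cut(prediction))
reference_tokens = list(jieba.cut(reference))

rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis_tokens), " ".join(reference_tokens))
print(scores[0]["rouge-l"]["f"])  # F1 of the longest common subsequence

bleu = sentence_bleu([reference_tokens], hypothesis_tokens, smoothing_function=SmoothingFunction().method3)
print(round(bleu * 100, 4))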

129
src/web_demo.py Normal file
View File

@@ -0,0 +1,129 @@
# coding=utf-8
# Implements user interface in browser for LLaMA fine-tuned with PEFT.
# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
import torch
import mdtex2html
import gradio as gr
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text): # copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>"+line
text = "".join(lines)
return text
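# predict is a generator; Gradio streams each yielded (chatbot, history) pair to the UI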
def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
inputs = tokenizer([input], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"num_beams": 1,
"max_length": max_length,
"repetition_penalty": 1.0
}
with torch.no_grad():
generation_output = model.generate(**inputs, **gen_kwargs)
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
history = history + [(input, response)]
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)
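As a rough illustration (not part of the file above, and assuming the parse_text defined in web_demo.py is in scope), parse_text rewrites fenced code blocks from the model output into HTML for the Gradio chatbot:

sample = "Here is code:\n```python\nprint('hi')\n```"
print(parse_text(sample))
# expected to wrap the snippet in <pre><code class="language-python"> ... </code></pre>,
# with special characters such as parentheses HTML-escaped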