release v0.1.0

Former-commit-id: 63c8d3a17cb18f0d8a8e37bfa147daf5bdd28ea9
author hiyouga
date 2023-07-18 00:18:25 +08:00
parent c08ff734a7
commit eac7f97337
30 changed files with 1513 additions and 309 deletions


@@ -28,7 +28,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer(
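require_version comes from transformers.utils.versions; it is a no-op when the installed package satisfies the constraint and raises an ImportError carrying the hint otherwise. A minimal sketch of the new floor on trl:

from transformers.utils.versions import require_version

# Passes silently with trl 0.4.7 or newer; otherwise raises ImportError with the hint below.
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")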


@@ -25,7 +25,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
@@ -46,12 +45,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
len_dataloader = len(self.dataloader)
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
max_steps = math.ceil(num_train_epochs * len_dataloader)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
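For concreteness, the sizing arithmetic in this hunk with assumed values (none of the numbers below come from the commit):

import math

# Assumed illustrative values, for the arithmetic only.
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
world_size = 2
num_train_epochs = 3.0
len_dataloader = 250   # batches yielded by self.dataloader per epoch

total_train_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * world_size
)                                                                              # 4 * 4 * 2 = 32
num_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)    # 62
max_steps = math.ceil(num_train_epochs * len_dataloader)                       # 750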
@@ -62,9 +62,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
@@ -77,7 +77,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
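LengthSampler is trl's helper for picking a target generation length per batch: constructed with (min_value, max_value), each call draws an integer from the half-open range [min_value, max_value). A small sketch, with max_target_length=128 assumed for illustration:

from trl.core import LengthSampler

max_target_length = 128                       # assumed value, for illustration only
length_sampler = LengthSampler(max_target_length // 2, max_target_length)

# Each call returns a random length in [64, 128), later passed to generate() to cap responses.
print([length_sampler() for _ in range(5)])   # e.g. [97, 72, 115, 64, 88]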
@@ -87,59 +87,45 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
batch = next(dataiter)
steps_trained += 1
for _ in range(self.config.gradient_accumulation_steps):
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
batch = next(dataiter)
steps_trained += 1
# Get responses
query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
queries, responses = [], []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Get response from model
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
queries: List[torch.Tensor] = []
responses: List[torch.Tensor] = []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
if response_length < 2: # make response have at least 2 tokens
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
else:
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
# Compute rewards
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = {
"loss": round(loss_meter.avg, 4),
"reward": round(reward_meter.avg, 4),
"learning_rate": stats["ppo/learning_rate"],
"epoch": round(step / num_steps_per_epoch, 2)
}
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
)
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
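The loop in this hunk strips the left padding off each query and the right padding off each response, guarantees at least two response tokens, swaps the value head in to score the pair, and only then runs the PPO step. A condensed, self-contained sketch of the padding and reward arithmetic (pad/eos ids, tensor contents, and the random values tensor are illustrative assumptions):

import torch

pad_id, eos_id = 0, 2                                     # assumed token ids, illustration only
query = torch.tensor([pad_id, pad_id, 5, 6, 7])           # left-padded prompt
response = torch.tensor([8, 9, pad_id, pad_id])           # right-padded generation

query_start = (query != pad_id).nonzero()[0]              # first non-pad index -> 2
response_end = (response != pad_id).nonzero()[-1] + 1     # one past the last non-pad -> 2

query = query[query_start:]                               # tensor([5, 6, 7])
if response_end < 2:                                      # keep at least two response tokens
    response = response.new_empty(2).fill_(eos_id)
else:
    response = response[:response_end]                    # tensor([8, 9])

# Reward: value-head output at the last position, cast to float32 (random stand-in values).
values = torch.randn(1, query.numel() + response.numel())
rewards = [r for r in values[:, -1].to(torch.float32)]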
@@ -150,10 +136,14 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
if self.control.should_training_stop:
if self.control.should_epoch_stop or self.control.should_training_stop:
break
@torch.inference_mode()
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
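Because max_steps can span more than one pass over the data, the iterator is rebuilt whenever a full epoch has been consumed; a minimal sketch of this cycling pattern with a list standing in for the dataloader:

dataloader = [f"batch_{i}" for i in range(3)]   # stand-in for self.dataloader
dataiter = iter(dataloader)
steps_trained = 0

for step in range(7):                           # max_steps may exceed len(dataloader)
    batch = next(dataiter)
    steps_trained += 1
    if steps_trained == len(dataloader):        # epoch boundary: restart the iterator
        dataiter = iter(dataloader)
        steps_trained = 0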
@torch.no_grad()
def generate(
self,
inputs: Dict[str, torch.Tensor],


@@ -4,7 +4,8 @@
import math
from trl import PPOConfig
from torch.optim import AdamW
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
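The AdamW and get_scheduler imports presumably feed the optimizer and learning-rate schedule that run_ppo builds for PPO updates; a self-contained sketch of that wiring, with a toy module and assumed numbers standing in for the real model and dataset:

import math
import torch
from torch.optim import AdamW
from transformers.optimization import get_scheduler

# Toy stand-ins so the sketch runs on its own; the real code uses the loaded LM and dataset.
model = torch.nn.Linear(8, 8)
num_examples, total_train_batch_size, num_train_epochs = 1000, 32, 3.0

optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
num_training_steps = int(num_train_epochs * math.ceil(num_examples / total_train_batch_size))
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)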
@@ -19,7 +20,8 @@ def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
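run_ppo now takes the callback list as a parameter (defaulting to [LogCallback()]) and forwards it to PPOPeftTrainer, so callers can inject their own TrainerCallback implementations. A usage sketch; PrintLossCallback and the *_args variables are placeholders, not part of this commit:

from transformers import TrainerCallback

class PrintLossCallback(TrainerCallback):
    # Hypothetical callback, shown only to illustrate the new parameter.
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(logs)

# model_args, data_args, training_args, finetuning_args are assumed to be parsed upstream.
run_ppo(model_args, data_args, training_args, finetuning_args,
        callbacks=[PrintLossCallback()])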
@@ -30,7 +32,7 @@ def run_ppo(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
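With this change the PPO batch_size equals per_device_train_batch_size * gradient_accumulation_steps, so one PPO optimization batch covers a full accumulation cycle of mini-batches. A sketch of the resulting PPOConfig with assumed numbers and a placeholder model path:

from trl import PPOConfig

per_device_train_batch_size = 4          # assumed illustrative values
gradient_accumulation_steps = 4

ppo_config = PPOConfig(
    model_name="path/to/model",          # placeholder, not a real checkpoint
    learning_rate=5e-5,
    mini_batch_size=per_device_train_batch_size,                           # 4
    batch_size=per_device_train_batch_size * gradient_accumulation_steps,  # 16
    gradient_accumulation_steps=gradient_accumulation_steps,
    ppo_epochs=1,
)
# 16 = 4 * 4, so each batch splits evenly into mini-batches across the accumulation steps.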
@@ -50,7 +52,7 @@ def run_ppo(
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[LogCallback()],
callbacks=callbacks,
config=ppo_config,
model=model,
ref_model=None,


@@ -2,7 +2,8 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from transformers import Seq2SeqTrainingArguments
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
@@ -18,7 +19,8 @@ def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -44,7 +46,7 @@ def run_rm(
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[LogCallback()],
callbacks=callbacks,
compute_metrics=compute_accuracy,
**trainer_kwargs
)
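compute_accuracy is the metric for the pairwise reward-modeling objective; a hedged sketch of what such a metric typically computes, namely the share of pairs in which the chosen response outscores the rejected one (this re-implementation and its prediction layout are assumptions, not code from this commit):

import numpy as np

def pairwise_accuracy(eval_preds):
    # Illustrative layout assumption: predictions hold the scores for the chosen
    # and rejected responses of each pair.
    chosen_scores, rejected_scores = eval_preds.predictions
    return {"accuracy": float(np.mean(chosen_scores > rejected_scores))}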