format style

commit 66e0e651b9 (parent 1750218057)
Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,27 +1,28 @@
+import math
import os
import sys
-import math
-import torch
-from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
-from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+import torch
+from tqdm import tqdm
+from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.trainer_pt_utils import remove_dummy_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
-from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
+from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
-from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
+from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
-from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
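The import changes above are mechanical: blocks are regrouped and alphabetized isort-style, and the rest of this commit applies the matching formatter conventions (magic trailing commas when a call is split across lines, two spaces before inline comments, spaces around slice bounds, no bare except). As a hedged sketch, a cleanup of this shape can be reproduced with the ruff linter/formatter; the project's actual tooling and settings are not shown in this diff:

import subprocess

# Assumes ruff is installed and its import-sorting (isort) rules are enabled in
# the configuration; the "src" path and settings are placeholders, not the project's.
subprocess.run(["ruff", "check", "--fix", "src"], check=True)  # sort imports, fix simple lint issues
subprocess.run(["ruff", "format", "src"], check=True)  # black-compatible reformatting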
@@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
-**kwargs
+**kwargs,
):
PPOTrainer.__init__(self, **kwargs)
@@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-**generating_args.to_dict()
+**generating_args.to_dict(),
)
self.state = TrainerState()
@@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
-): # quantized models are already set on the correct device
+):  # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
@@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
-logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
-    total_train_batch_size
-))
+logger.info(
+    " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+        total_train_batch_size
+    )
+)
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps))
@@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
-self.tokenizer.padding_side = "right" # change padding side
+self.tokenizer.padding_side = "right"  # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+mini_batch_queries, mini_batch_responses = self.get_inputs(
+    batch[idx : idx + self.config.mini_batch_size]
+)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
@@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
# Run PPO step
stats = self.step(queries, responses, rewards)
-self.tokenizer.padding_side = "left" # restore padding side
+self.tokenizer.padding_side = "left"  # restore padding side
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
-except:
+except Exception:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)
-if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
+if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
-epoch=round(step / steps_in_epoch, 2)
+epoch=round(step / steps_in_epoch, 2),
)
tqdm.write(str(logs))
logs["step"] = step
@@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss_meter.reset()
reward_meter.reset()
-if (step+1) % self.args.save_steps == 0: # save checkpoint
-    self.save_model(os.path.join(
-        self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
-    ))
+if (step + 1) % self.args.save_steps == 0:  # save checkpoint
+    self.save_model(
+        os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+    )
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
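For reference, the checkpoint directory assembled by the save_model call above follows the usual transformers naming scheme ("checkpoint-<global_step>"). A minimal sketch with placeholder values standing in for self.args.output_dir and self.state.global_step:

import os

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

output_dir, global_step = "ppo_output", 500  # placeholders, not the project's defaults
print(os.path.join(output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, global_step)))  # ppo_output/checkpoint-500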
@@ -207,35 +212,33 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
-if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
generate_output: torch.Tensor = unwrapped_model.generate(
-    generation_config=self.generation_config,
-    logits_processor=get_logits_processor(),
-    **batch
+    generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query = batch["input_ids"].detach().cpu()
-response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
+response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
-response_length = 1 # allow empty response
+response_length = 1  # allow empty response
else:
response_length = response_index[-1].item() + 1
-queries.append(query[i, query_start_index:]) # remove padding from left
-responses.append(response[i, :response_length]) # remove padding from right
+queries.append(query[i, query_start_index:])  # remove padding from left
+responses.append(response[i, :response_length])  # remove padding from right
return queries, responses
@@ -244,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self,
queries: List[torch.Tensor],
responses: List[torch.Tensor],
-unwrapped_model: "AutoModelForCausalLMWithValueHead"
+unwrapped_model: "AutoModelForCausalLMWithValueHead",
) -> List[torch.Tensor]:
r"""
Computes scores using given reward model.
@@ -264,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses)
-with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
-if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
+if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
-rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
+rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
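The reward for each sample is read from the value head at the last non-padding position of the sequence. A toy version of that lookup, with 0 playing the role of the pad token id (not the project's tokenizer):

import torch

input_ids = torch.tensor([5, 8, 3, 0, 0])  # three real tokens followed by padding
end_indexes = (input_ids != 0).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
print(end_index)  # 2, the index of the last real token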
@@ -289,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
-response_masks: Optional[torch.Tensor] = None
+response_masks: Optional[torch.Tensor] = None,
):
r"""
Calculates model outputs in multiple batches.
@@ -312,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
-with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
logits, _, values = model(**input_kwargs)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -325,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
-if attention_mask[j, 0] == 0: # offset left padding
+if attention_mask[j, 0] == 0:  # offset left padding
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:
-response_masks_batch = torch.cat(
-    (torch.zeros_like(query_batch[j]), response_masks_batch[j])
-)[1:]
+response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
masks[j, :start] = 0
masks[j, end:] = 0
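Several of the touched comments concern the trainer flipping tokenizer.padding_side: right padding while collecting rollouts and running the PPO step, left padding restored for generation. A standalone sketch of the difference, using a generic public tokenizer rather than the project's model:

from transformers import AutoTokenizer

# "gpt2" is only a convenient public checkpoint for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
prompts = ["hi", "a much longer prompt"]

tokenizer.padding_side = "left"  # pad tokens in front, as used for generation
left_ids = tokenizer(prompts, padding=True)["input_ids"]

tokenizer.padding_side = "right"  # pad tokens at the end, as used for the PPO training step
right_ids = tokenizer(prompts, padding=True)["input_ids"]

print(left_ids[0])  # padding ids come first, the prompt ends at the last position
print(right_ids[0])  # the prompt comes first, padding ids fill the tail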

View File

@@ -1,9 +1,11 @@
import json
-import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+import torch
from ...extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
@@ -21,16 +23,18 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-if target == "reward": # save default head temporarily
+if target == "reward":  # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
-model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
-model.v_head.load_state_dict({
-    "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-    "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
-})
+model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+model.v_head.load_state_dict(
+    {
+        "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
+        "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
+    }
+)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
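replace_model above stashes the current value head, activates the requested LoRA adapter, and loads the matching summary weights back into model.v_head. A toy illustration of that stash-and-restore pattern, with a bare torch.nn.Linear standing in for the value head (not the project's API):

import torch

head = torch.nn.Linear(4, 1)  # stand-in for the value head's summary layer

# Stash the "default" head, as replace_model(..., target="reward") does first.
default_state = {k: v.detach().clone() for k, v in head.state_dict().items()}

# Load other weights for scoring (random here, only the shapes matter).
head.load_state_dict({"weight": torch.randn(1, 4), "bias": torch.randn(1)})

# Restore the default head afterwards, as replace_model(..., target="default") does.
head.load_state_dict(default_state)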

View File

@@ -1,23 +1,26 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
-from trl import PPOConfig
+from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW
-from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
+from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model_and_tokenizer
-from ...train.utils import create_ref_model, create_reward_model
from ...train.ppo.trainer import CustomPPOTrainer
+from ...train.utils import create_ref_model, create_reward_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
-from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_ppo(
@@ -26,12 +29,14 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
-callbacks: Optional[List["TrainerCallback"]] = None
+callbacks: Optional[List["TrainerCallback"]] = None,
):
-model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
+model, tokenizer = load_model_and_tokenizer(
+    model_args, finetuning_args, training_args.do_train, add_valuehead=True
+)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
-tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
+tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
@@ -55,7 +60,7 @@ def run_ppo(
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
-accelerator_kwargs={"step_scheduler_with_optimizer": False}
+accelerator_kwargs={"step_scheduler_with_optimizer": False},
)
# Create optimizer and scheduler
@@ -70,7 +75,7 @@ def run_ppo(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
-num_training_steps=num_training_steps
+num_training_steps=num_training_steps,
)
# Initialize our Trainer
@@ -88,7 +93,7 @@ def run_ppo(
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
-lr_scheduler=lr_scheduler
+lr_scheduler=lr_scheduler,
)
# Training
@@ -97,6 +102,6 @@ def run_ppo(
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
-ppo_trainer.save_state() # must be called after save_model to have a folder
+ppo_trainer.save_state()  # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])
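The optimizer and scheduler hunks above only add trailing commas; the wiring itself is unchanged. For context, a self-contained sketch of that get_scheduler call, with a dummy parameter and made-up step counts in place of the values computed in run_ppo:

import torch
from torch.optim import AdamW
from transformers.optimization import get_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter standing in for the PPO model
optimizer = AdamW(params, lr=1e-5)
num_training_steps = 1000  # hypothetical value

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

for _ in range(num_training_steps):
    optimizer.step()
    lr_scheduler.step()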