format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
@@ -1,27 +1,28 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
|
||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
||||
from ...extras.callbacks import LogCallback, FixValueHeadModelCallback
|
||||
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from .utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
|
||||
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from ...hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
reward_model: "AutoModelForCausalLMWithValueHead",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
|
||||
@@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
**generating_args.to_dict()
|
||||
**generating_args.to_dict(),
|
||||
)
|
||||
|
||||
self.state = TrainerState()
|
||||
@@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
if not (
|
||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
): # quantized models are already set on the correct device
|
||||
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
||||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
@@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
logger.info(" Num examples = {}".format(num_examples))
|
||||
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
||||
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
||||
total_train_batch_size
|
||||
))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
||||
total_train_batch_size
|
||||
)
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
||||
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
||||
logger.info(" Total training steps = {}".format(max_steps))
|
||||
@@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
queries, responses, rewards = [], [], []
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(
|
||||
batch[idx : idx + self.config.mini_batch_size]
|
||||
)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||
queries.extend(mini_batch_queries)
|
||||
responses.extend(mini_batch_responses)
|
||||
@@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
@@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
self.log_stats(stats, batch, rewards)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
|
||||
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
|
||||
logs = dict(
|
||||
loss=round(loss_meter.avg, 4),
|
||||
reward=round(reward_meter.avg, 4),
|
||||
learning_rate=stats["ppo/learning_rate"],
|
||||
epoch=round(step / steps_in_epoch, 2)
|
||||
epoch=round(step / steps_in_epoch, 2),
|
||||
)
|
||||
tqdm.write(str(logs))
|
||||
logs["step"] = step
|
||||
@@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
loss_meter.reset()
|
||||
reward_meter.reset()
|
||||
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(
|
||||
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
||||
))
|
||||
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(
|
||||
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
||||
)
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
@@ -207,35 +212,33 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
|
||||
if self.model_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
response_length = 1 # allow empty response
|
||||
else:
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
|
||||
return queries, responses
|
||||
|
||||
@@ -244,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self,
|
||||
queries: List[torch.Tensor],
|
||||
responses: List[torch.Tensor],
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead"
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead",
|
||||
) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
@@ -264,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
@@ -289,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
return_logits: Optional[bool] = False,
|
||||
response_masks: Optional[torch.Tensor] = None
|
||||
response_masks: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
@@ -312,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
@@ -325,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
response_masks_batch = torch.cat(
|
||||
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
||||
)[1:]
|
||||
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
|
||||
|
||||
masks[j, :start] = 0
|
||||
masks[j, end:] = 0
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras.packages import is_requests_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
@@ -21,16 +23,18 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
||||
})
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict(
|
||||
{
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
|
||||
@@ -1,23 +1,26 @@
|
||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
|
||||
import math
|
||||
from trl import PPOConfig
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from torch.optim import AdamW
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorWithPadding
|
||||
from transformers.optimization import get_scheduler
|
||||
from trl import PPOConfig
|
||||
|
||||
from ...data import get_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.utils import create_ref_model, create_reward_model
|
||||
from ...train.ppo.trainer import CustomPPOTrainer
|
||||
from ...train.utils import create_ref_model, create_reward_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from ...hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
@@ -26,12 +29,14 @@ def run_ppo(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
model, tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, training_args.do_train, add_valuehead=True
|
||||
)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# Create reference model and reward model
|
||||
@@ -55,7 +60,7 @@ def run_ppo(
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False},
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
@@ -70,7 +75,7 @@ def run_ppo(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
@@ -88,7 +93,7 @@ def run_ppo(
|
||||
dataset=dataset,
|
||||
data_collator=data_collator,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -97,6 +102,6 @@ def run_ppo(
|
||||
ppo_trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
Reference in New Issue
Block a user