improve rlhf

Former-commit-id: e441780e3db256ca09a442ea9254e7ce16898a07
This commit is contained in:
hiyouga
2024-07-02 22:23:08 +08:00
parent f0b01803ea
commit e6ba7ef3e6
8 changed files with 55 additions and 114 deletions

View File

@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -105,6 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with == "tensorboard": # tensorboard raises error about accelerator_kwargs
ppo_config.log_with = None
# Create optimizer and scheduler
if training_args.max_steps > 0:
@@ -143,6 +146,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
@@ -339,11 +343,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: torch.Tensor = unwrapped_model.generate(
generate_output: "torch.Tensor" = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
@@ -354,12 +358,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
queries, responses = [], []
for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
if len(response_indexes) == 0: # allow empty response
response_length = 1
elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id: # include eos token
response_length = response_indexes[-1].item() + 2
else:
response_length = response_index[-1].item() + 1
response_length = response_indexes[-1].item() + 1
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
@@ -382,7 +388,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return get_rewards_from_server(self.reward_model, messages)
batch = self.prepare_model_inputs(queries, responses)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
@@ -392,7 +398,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
_, _, values = reward_model(**batch, return_dict=True, use_cache=False)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@@ -400,13 +406,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_chatglm_model: # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
return rewards
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.to(torch.float32).detach().cpu() # use fp32 type
@PPODecorators.empty_device_cache()
def batched_forward_pass(
@@ -440,7 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
attention_mask = input_kwargs["attention_mask"]
with self.amp_context: # support bf16
logits, _, values = model(**input_kwargs)
logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
if self.is_chatglm_model:
values = torch.transpose(values, 0, 1)