improve rlhf
Former-commit-id: e441780e3db256ca09a442ea9254e7ce16898a07
@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
+from transformers.trainer import DEFAULT_CALLBACKS
 from transformers.trainer_callback import CallbackHandler
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -105,6 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
         ]
         ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+        if ppo_config.log_with == "tensorboard":  # tensorboard raises error about accelerator_kwargs
+            ppo_config.log_with = None
 
         # Create optimizer and scheduler
         if training_args.max_steps > 0:
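The two added lines only affect logging: ppo_config.accelerator_kwargs is forwarded to the underlying accelerate Accelerator, and the inline comment notes that trl's tensorboard tracker errors out on those kwargs (presumably because the DeepSpeed plugin is not serializable), so logging falls back to None. A minimal sketch of that wiring, assuming trl's PPOConfig and its default accelerator_kwargs dict; all values below are illustrative:

from accelerate.utils import DistributedDataParallelKwargs
from trl import PPOConfig

ppo_config = PPOConfig(batch_size=8, mini_batch_size=2, log_with="tensorboard")  # illustrative values
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
    DistributedDataParallelKwargs(find_unused_parameters=False)
]
# ppo_config.accelerator_kwargs["deepspeed_plugin"] = ...  # as in the hunk above
if ppo_config.log_with == "tensorboard":  # these kwargs break trl's tensorboard tracker
    ppo_config.log_with = None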
@@ -143,6 +146,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.control = TrainerControl()
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
@@ -339,11 +343,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 batch[k] = v[:, start_index:]
 
         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
-            unwrapped_model = self.accelerator.unwrap_model(self.model)  # issue in trl v0.8.6
+            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
             if self.model_args.upcast_layernorm:
                 layernorm_params = dump_layernorm(unwrapped_model)
 
-            generate_output: torch.Tensor = unwrapped_model.generate(
+            generate_output: "torch.Tensor" = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
             if self.model_args.upcast_layernorm:
@@ -354,12 +358,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         queries, responses = [], []
         for i in range(len(query)):
             query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
-            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
-            if len(response_index) == 0:
-                response_length = 1  # allow empty response
+            if len(response_indexes) == 0:  # allow empty response
+                response_length = 1
+            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
+                response_length = response_indexes[-1].item() + 2
             else:
-                response_length = response_index[-1].item() + 1
+                response_length = response_indexes[-1].item() + 1
 
             queries.append(query[i, query_start_index:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
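The new elif branch matters when the tokenizer reuses its eos token as the pad token: the mask response[i] != pad_token_id then also hides the real eos, so one extra position is kept to preserve it. A self-contained toy illustration, assuming right padding; the token ids are made up:

import torch

pad_id = eos_id = 2  # assumption: tokenizer uses eos as pad
response = torch.tensor([5, 6, 7, eos_id, pad_id, pad_id])  # one right-padded response

response_indexes = (response != pad_id).nonzero()
if len(response_indexes) == 0:  # allow empty response
    response_length = 1
elif eos_id == pad_id:  # include eos token
    response_length = response_indexes[-1].item() + 2
else:
    response_length = response_indexes[-1].item() + 1

print(response[:response_length])  # tensor([5, 6, 7, 2]) -> the eos is kept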
@@ -382,7 +388,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
             return get_rewards_from_server(self.reward_model, messages)
 
-        batch = self.prepare_model_inputs(queries, responses)
+        batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
 
         if self.finetuning_args.reward_model_type == "lora":
@@ -392,7 +398,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model
 
         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
 
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -400,13 +406,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.is_chatglm_model:  # assume same architecture
             values = torch.transpose(values, 0, 1)
 
-        rewards = []
-        for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
-
-        return rewards
+        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+        return rewards.to(torch.float32).detach().cpu()  # use fp32 type
 
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
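The removed loop and the new gather pick out the same quantity: the value-head output at the last non-padded token of each sequence. The vectorized form assumes the attention mask is 0 exactly on pad positions, as a standard right-padding collator produces. A toy comparison of the two forms:

import torch

values = torch.arange(10, dtype=torch.float32).view(2, 5)           # fake value-head outputs
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])   # right padding

# New form: gather the value at the last attended position of each row.
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1            # [[2], [4]]
rewards_vec = values.gather(dim=-1, index=last_index)                # [[2.], [9.]]

# Old loop form, written against the same mask for comparison.
rewards_loop = [values[i, attention_mask[i].sum() - 1] for i in range(values.size(0))]

assert torch.equal(rewards_vec.squeeze(-1), torch.stack(rewards_loop))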
@@ -440,7 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             attention_mask = input_kwargs["attention_mask"]
 
             with self.amp_context:  # support bf16
-                logits, _, values = model(**input_kwargs)
+                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
 
             if self.is_chatglm_model:
                 values = torch.transpose(values, 0, 1)
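Passing return_dict=True and use_cache=False here mirrors the reward-model call further up: trl's AutoModelForCausalLMWithValueHead returns a (logits, loss, values) tuple, which is why the trainer unpacks three outputs, and the generation KV cache is useless for a one-shot training forward pass, so disabling it saves memory. A hedged sketch of that call shape, assuming trl and transformers are installed; the model name is only illustrative:

import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("PPO needs a value estimate per token.", return_tensors="pt")
with torch.no_grad():
    logits, _, values = model(**inputs, return_dict=True, use_cache=False)

print(logits.shape)  # (1, seq_len, vocab_size)
print(values.shape)  # (1, seq_len): one scalar value per token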