remove rlhf support for chatglm2&3

Former-commit-id: bcbb5b71961b89719bffb0d202c431c82e6067cc
Author: hiyouga
Date: 2024-07-02 23:03:17 +08:00
Parent: 579997688f
Commit: ca548af2a2
2 changed files with 2 additions and 18 deletions


@@ -150,14 +150,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
-        self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
+        self.amp_context = torch.autocast(self.current_device.type)
         warnings.simplefilter("ignore")  # remove gc warnings on ref model
         if finetuning_args.reward_model_type == "full":
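As background for the autocast line that replaces the deleted one above, here is a minimal, self-contained sketch (not taken from this repository; the model and tensor names are invented). With no explicit dtype argument, torch.autocast falls back to the device default (float16 on CUDA, bfloat16 on CPU), whereas the deleted variant pinned the dtype to the configured compute dtype.

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# No explicit dtype: autocast uses the device default
# (float16 on CUDA, bfloat16 on CPU).
amp_context = torch.autocast(device_type)

model = torch.nn.Linear(16, 4).to(device_type)   # invented stand-in for the policy model
inputs = torch.randn(2, 16, device=device_type)

with amp_context:
    outputs = model(inputs)  # matmul runs in the reduced precision chosen by autocast

print(outputs.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU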
@@ -403,9 +399,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
-        if self.is_chatglm_model:  # assume same architecture
-            values = torch.transpose(values, 0, 1)
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type
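The reward extraction kept by this hunk indexes the value head output, laid out as (batch, seq_len), at the last attended token of each sequence; the deleted branch existed only because ChatGLM's value head apparently returned the batch and sequence dimensions swapped, hence the transpose. A minimal sketch with invented tensors of what the surviving gather call computes:

import torch

values = torch.tensor([[0.1, 0.5, 0.9, 0.0],   # last real token at index 2
                       [0.2, 0.7, 0.0, 0.0]])  # last real token at index 1
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# index of the last attended token per sequence: shape (batch, 1)
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1
rewards = values.gather(dim=-1, index=last_index)  # tensor([[0.9], [0.7]])
print(rewards.float().detach())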
@@ -443,9 +436,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            with self.amp_context:  # support bf16
                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
-            if self.is_chatglm_model:
-                values = torch.transpose(values, 0, 1)
             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
             masks = torch.zeros_like(attention_mask)
             masks[:, :-1] = attention_mask[:, 1:]
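The remaining context lines compute per-token log-probabilities and a matching mask shifted by one position, since the logits at step t score the token at step t + 1. A plain-PyTorch sketch of that pattern follows (tensor names invented; the logprobs_from_logits helper used above performs essentially the log-softmax-and-gather step shown here):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])

# log P(input_ids[:, t + 1] | tokens up to t) for t = 0 .. seq_len - 2
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
logprobs = torch.gather(log_probs, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

# mask position t according to whether token t + 1 is a real (attended) token
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]

print(logprobs.shape)  # torch.Size([2, 4])
print(masks)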