Former-commit-id: ff9a3f73961a362d0ddc22079f80a85465fffda8
hiyouga
2024-04-01 22:53:52 +08:00
parent 85726c91ce
commit 1dc963caa6
4 changed files with 23 additions and 15 deletions

View File

@@ -95,7 +95,10 @@ class CustomDPOTrainer(DPOTrainer):
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
         all_logits: "torch.Tensor" = model(
-            input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True,
+            use_cache=False,
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
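For context on the keyword being added: use_cache controls whether a Hugging Face causal LM builds and returns past_key_values during the forward pass. That cache only helps incremental decoding; in a full-sequence training forward it just holds extra tensors (and triggers a warning when gradient checkpointing is enabled). Below is a minimal, self-contained sketch of the pattern these trainer hunks converge on, using "gpt2" purely as a stand-in checkpoint, not the model these trainers actually receive:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is a hypothetical stand-in; the real trainers get their model from the project's loader.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch = tokenizer(["a short training sample"], return_tensors="pt")

# Full-sequence training forward: use_cache=False skips building past_key_values,
# which are only useful for token-by-token generation and would otherwise sit in
# memory for the whole backward pass.
all_logits = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    return_dict=True,
    use_cache=False,
).logits.to(torch.float32)

print(all_logits.shape)  # (batch_size, sequence_length, vocab_size)

The same keyword is threaded through the DPO, ORPO, and PPO forwards in this commit; the hunks differ only in whether the call is formatted on one line or several.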

View File

@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
         Computes the average log probabilities of the labels under the given logits.
         """
         all_logits: "torch.Tensor" = model(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
+            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(

View File

@@ -353,7 +353,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
 
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
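The PPO hunk applies the same keyword inside the reward-scoring forward that runs under autocast. A sketch of that shape, under stated assumptions: a plain causal LM stands in for the value-head reward model, and torch.autocast stands in for the torch.cuda.amp.autocast call in the diff (both substitutions are illustrative, not the repository's code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in "reward model"; the real trainer scores with a value-head model and unpacks
# its third return value as the per-token values.
reward_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch = tokenizer(["query text followed by response text"], return_tensors="pt")

# Mixed-precision scoring pass, enabled only when CUDA is present so the sketch also runs on CPU.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()):
    outputs = reward_model(
        **batch,
        output_hidden_states=True,  # a value head would read hidden states from this pass
        return_dict=True,
        use_cache=False,  # scoring is one full-sequence pass, so the KV cache is dead weight
    )

last_hidden = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)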