Refactor data preprocessing; fix MLLM RLHF
Former-commit-id: 53ff2dd24f9121ea30c95063bb72e49a9b31e980
This commit is contained in:
@@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
kl_logits = model(
|
||||
input_ids=batch["kl_input_ids"],
|
||||
attention_mask=batch["kl_attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
kl_model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
target_logits = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
if "kl_token_type_ids" in batch:
|
||||
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
|
||||
|
||||
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "token_type_ids" in batch:
|
||||
model_inputs["token_type_ids"] = batch["token_type_ids"]
|
||||
|
||||
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
target_logps = self.get_batch_logps(
|
||||
logits=target_logits,
|
||||
|
||||
Reference in New Issue
Block a user