refactor data preprocessing, fix mllm rlhf

Former-commit-id: 53ff2dd24f9121ea30c95063bb72e49a9b31e980
This commit is contained in:
hiyouga
2024-05-24 04:08:25 +08:00
parent 1078611259
commit bf59383783
15 changed files with 572 additions and 464 deletions

View File

@@ -4,7 +4,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import BatchEncoding, Trainer
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
@@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits: "torch.Tensor" = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,