LLaMA-Factory/src/llmtuner/tuner/rm/trainer.py
hiyouga 6261fb362a modity code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
2023-07-15 16:54:28 +08:00


import torch
from typing import Dict, List, Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel

from llmtuner.tuner.core.trainer import PeftTrainer


class PairwisePeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True  # override property to return eval_loss
    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes pairwise loss. The first n examples in the batch are the chosen responses
        and the last n examples are the rejected ones.

        We use the score on the EOS token to represent the reward of the whole sentence.

        Subclass and override to inject custom behavior. This method should not be called
        directly by external scripts.

        Note that the first element (the loss) will be removed from the output tuple
        by the caller during evaluation.

        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)  # per-token scores from the value head
        # take the score at the last position (the EOS token) as the sequence reward,
        # then split the batch into chosen (first half) and rejected (second half)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        # pairwise loss: maximize the log-sigmoid of the reward margin
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
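
A minimal, self-contained sketch (not part of the file above) of how this pairwise objective behaves on toy values. The names `n` and `values` here are illustrative stand-ins for the batch split and the value-head output, with the chosen responses in the first half of the batch:

import torch

n = 2  # a batch of 2n sequences: n chosen followed by n rejected
# hypothetical last-token scores from a value head, shape (2n, seq_len)
values = torch.tensor([[1.2], [0.8], [-0.3], [0.5]])
r_accept, r_reject = values[:, -1].split(n, dim=0)  # ([1.2, 0.8], [-0.3, 0.5])
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss.item())  # ~0.378; shrinks as chosen rewards exceed rejected ones

This is the standard Bradley-Terry pairwise objective: the loss is small when each chosen response scores higher than its rejected counterpart, and grows as the margin reverses.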