modity code structure

Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
hiyouga
2023-07-15 16:54:28 +08:00
parent fa06b168ab
commit 6261fb362a
57 changed files with 1999 additions and 1816 deletions

View File

@@ -0,0 +1,19 @@
import torch
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
return super().__call__(features)