format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions


@@ -1,6 +1,7 @@
-import torch
 from dataclasses import dataclass
 from typing import Any, Dict, List, Sequence, Tuple
+
+import torch
 from transformers import DataCollatorForSeq2Seq
@@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
             padded_tensor[start:end] = feature[start:end]
             padded_labels.append(padded_tensor)
-        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         r"""
@@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
         for key in ("chosen_ids", "rejected_ids"):
             for feature in features:
                 prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
-                concatenated_features.append({
-                    "input_ids": feature["prompt_ids"] + feature[key],
-                    "attention_mask": [1] * (prompt_len + answer_len)
-                })
+                concatenated_features.append(
+                    {
+                        "input_ids": feature["prompt_ids"] + feature[key],
+                        "attention_mask": [1] * (prompt_len + answer_len),
+                    }
+                )
                 label_positions.append((prompt_len, answer_len))
 
         batch = self.tokenizer.pad(
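
For reference, a hedged sketch of the 2x batch layout this `__call__` builds: all chosen sequences come first, then all rejected ones, so downstream DPO code can split the padded model inputs in half. The gpt2 tokenizer and the toy token ids are assumptions for illustration only.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

features = [
    {"prompt_ids": [1, 2], "chosen_ids": [3, 4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7], "chosen_ids": [8, 9], "rejected_ids": [10, 11]},
]

concatenated_features, label_positions = [], []
for key in ("chosen_ids", "rejected_ids"):  # chosen rows first, then rejected
    for feature in features:
        prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
        concatenated_features.append(
            {
                "input_ids": feature["prompt_ids"] + feature[key],
                "attention_mask": [1] * (prompt_len + answer_len),
            }
        )
        label_positions.append((prompt_len, answer_len))

# Pad the 2 * batch_size rows to a common length in one call.
batch = tokenizer.pad(concatenated_features, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # torch.Size([4, 5])
```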