mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-18 23:13:08 +00:00
alter rewards data type
Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
from transformers import DataCollatorWithPadding
|
||||
from transformers import DataCollatorWithPadding, BatchEncoding
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
attention_mask = attention_mask.bool()
|
||||
return attention_mask
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
batch["input_ids"] = input_ids
|
||||
batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
|
||||
|
||||
return batch
|
||||
return BatchEncoding(batch)
|
||||
|
||||
Reference in New Issue
Block a user