mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -52,7 +53,12 @@ def test_feedback_data(num_samples: int):
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
|
||||
Reference in New Issue
Block a user