mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -19,6 +19,7 @@ import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -55,8 +56,13 @@ def test_unsupervised_data(num_samples: int):
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
ref_labels = ref_ids[len(ref_input_ids) :]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
ref_labels = ref_input_ids[len(ref_prompt_ids) :]
|
||||
assert train_dataset["input_ids"][index] == ref_prompt_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
|
||||
Reference in New Issue
Block a user