mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -59,7 +60,16 @@ def test_supervised_single_turn(num_samples: int):
|
||||
{"role": "assistant", "content": original_data["output"][index]},
|
||||
]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_label_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@@ -73,6 +83,10 @@ def test_supervised_multi_turn(num_samples: int):
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
|
||||
# cannot test the label ids in multi-turn case
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@@ -86,9 +100,12 @@ def test_supervised_train_on_prompt(num_samples: int):
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
assert train_dataset["input_ids"][index] == ref_ids
|
||||
assert train_dataset["labels"][index] == ref_ids
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@@ -103,7 +120,13 @@ def test_supervised_mask_history(num_samples: int):
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_label_ids
|
||||
|
||||
Reference in New Issue
Block a user