mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -63,13 +64,21 @@ def test_pairwise_data(num_samples: int):
|
||||
rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
|
||||
chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
|
||||
rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
|
||||
|
||||
ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
|
||||
chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
|
||||
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
|
||||
ref_chosen_prompt_ids = ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True)
|
||||
ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
|
||||
rejected_prompt_len = len(
|
||||
ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
|
||||
)
|
||||
ref_rejected_prompt_ids = ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_chosen_input_ids = ref_chosen_input_ids["input_ids"]
|
||||
ref_rejected_input_ids = ref_rejected_input_ids["input_ids"]
|
||||
ref_chosen_prompt_ids = ref_chosen_prompt_ids["input_ids"]
|
||||
ref_rejected_prompt_ids = ref_rejected_prompt_ids["input_ids"]
|
||||
|
||||
chosen_prompt_len = len(ref_chosen_prompt_ids)
|
||||
rejected_prompt_len = len(ref_rejected_prompt_ids)
|
||||
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
|
||||
ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
|
||||
assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
|
||||
assert train_dataset["chosen_labels"][index] == ref_chosen_labels
|
||||
|
||||
Reference in New Issue
Block a user