modify some style

Former-commit-id: 053062abc007014a7fde95c5ae9f4d859893d8ad
This commit is contained in:
BUAADreamer
2024-04-25 22:04:09 +08:00
parent 8b2a735c14
commit 514ffafc12
4 changed files with 50 additions and 8 deletions

View File

@@ -324,10 +324,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .formatter import SLOTS, Formatter
logger = get_logger(__name__)
@@ -103,9 +104,7 @@ class Template:
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]],
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts elements to token ids.
@@ -391,6 +390,7 @@ _register_template(
),
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@@ -403,6 +403,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="atom",
format_user=StringFormatter(
@@ -411,18 +412,21 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@@ -431,11 +435,13 @@ _register_template(
force_system=True,
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -447,6 +453,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -456,6 +463,7 @@ _register_template(
force_system=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -470,6 +478,7 @@ _register_template(
force_system=True,
)
_register_template(
name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -489,6 +498,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -499,6 +509,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -510,12 +521,14 @@ _register_template(
replace_eos=True,
)
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="cohere",
format_user=StringFormatter(
@@ -530,6 +543,7 @@ _register_template(
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -537,6 +551,7 @@ _register_template(
force_system=True,
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -562,6 +577,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -569,6 +585,7 @@ _register_template(
force_system=True,
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -584,6 +601,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@@ -591,12 +609,14 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -604,12 +624,14 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -622,6 +644,7 @@ _register_template(
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -630,6 +653,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -646,6 +670,7 @@ _register_template(
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -662,6 +687,7 @@ _register_template(
),
)
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -669,6 +695,7 @@ _register_template(
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="llama3",
format_user=StringFormatter(
@@ -695,6 +722,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@@ -702,6 +730,7 @@ _register_template(
force_system=True,
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@@ -710,6 +739,7 @@ _register_template(
force_system=True,
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
@@ -718,6 +748,7 @@ _register_template(
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -725,6 +756,7 @@ _register_template(
force_system=True,
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -736,6 +768,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -747,6 +780,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -754,6 +788,7 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@@ -764,6 +799,7 @@ _register_template(
force_system=True,
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -773,6 +809,7 @@ _register_template(
),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -783,11 +820,13 @@ _register_template(
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@@ -807,6 +846,7 @@ _register_template(
stop_words=["<|End|>"],
)
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -815,6 +855,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -823,6 +864,7 @@ _register_template(
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@@ -831,12 +873,14 @@ _register_template(
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),

View File

@@ -43,7 +43,7 @@ def run_sft(
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Override the decoding parameters of Seq2SeqTrainer

View File

@@ -19,6 +19,7 @@ from .sft import run_sft
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = get_logger(__name__)