fix baichuan templates
Former-commit-id: f48a49e835b32f3991cfad8874c7b9c78953809f
This commit is contained in:
@@ -20,6 +20,7 @@ class Template:
|
||||
sep: List[Union[str, Dict[str, str]]]
|
||||
stop_words: List[str]
|
||||
use_history: bool
|
||||
efficient_eos: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
@@ -74,19 +75,19 @@ class Template:
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if (
|
||||
tokenizer.bos_token_id is not None
|
||||
and getattr(tokenizer, "add_bos_token", True)
|
||||
): # baichuan-13b has no bos token
|
||||
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
|
||||
bos_ids = [tokenizer.bos_token_id]
|
||||
else:
|
||||
bos_ids = [] # bos token is optional
|
||||
else: # baichuan, qwen and gpt2 models has no bos token
|
||||
bos_ids = []
|
||||
|
||||
if tokenizer.eos_token_id is not None:
|
||||
eos_ids = [tokenizer.eos_token_id]
|
||||
else:
|
||||
if tokenizer.eos_token_id is None:
|
||||
raise ValueError("EOS token is required.")
|
||||
|
||||
if self.efficient_eos: # used in baichuan, qwen and gpt2 models
|
||||
eos_ids = []
|
||||
else:
|
||||
eos_ids = [tokenizer.eos_token_id]
|
||||
|
||||
return bos_ids, eos_ids
|
||||
|
||||
def _encode(
|
||||
@@ -186,7 +187,8 @@ def register_template(
|
||||
system: str,
|
||||
sep: List[Union[str, Dict[str, str]]],
|
||||
stop_words: Optional[List[str]] = [],
|
||||
use_history: Optional[bool] = True
|
||||
use_history: Optional[bool] = True,
|
||||
efficient_eos: Optional[bool] = False
|
||||
) -> None:
|
||||
template_class = Llama2Template if "llama2" in name else Template
|
||||
templates[name] = template_class(
|
||||
@@ -195,7 +197,8 @@ def register_template(
|
||||
system=system,
|
||||
sep=sep,
|
||||
stop_words=stop_words,
|
||||
use_history=use_history
|
||||
use_history=use_history,
|
||||
efficient_eos=efficient_eos
|
||||
)
|
||||
|
||||
|
||||
@@ -206,15 +209,6 @@ def get_template_and_fix_tokenizer(
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
additional_special_tokens = template.stop_words
|
||||
if len(template.stop_words): # inplace method
|
||||
if tokenizer.eos_token_id is not None:
|
||||
additional_special_tokens.append(tokenizer.eos_token)
|
||||
|
||||
tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
|
||||
additional_special_tokens.pop(0)
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
@@ -227,7 +221,7 @@ def get_template_and_fix_tokenizer(
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=additional_special_tokens),
|
||||
dict(additional_special_tokens=template.stop_words),
|
||||
replace_additional_special_tokens=False
|
||||
)
|
||||
return template
|
||||
@@ -466,18 +460,18 @@ register_template(
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"</s>", # internlm cannot replace eos token
|
||||
"<eoa>"
|
||||
]
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
Used for training and inference of the fine-tuned models.
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan",
|
||||
@@ -487,39 +481,17 @@ register_template(
|
||||
prompt=[
|
||||
{"token": "<reserved_102>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
Used for inference of the original model.
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan_eval",
|
||||
prefix=[
|
||||
"{{system}}",
|
||||
{"token": "<reserved_102>"} # user token
|
||||
],
|
||||
prompt=[
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<reserved_102>" # user token
|
||||
]
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
|
||||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
|
||||
Used for training and inference of the fine-tuned models.
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
@@ -529,33 +501,11 @@ register_template(
|
||||
prompt=[
|
||||
{"token": "<reserved_106>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_107>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
|
||||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
|
||||
Used for inference of the original model.
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan2_eval",
|
||||
prefix=[
|
||||
"{{system}}",
|
||||
{"token": "<reserved_106>"} # user token
|
||||
],
|
||||
prompt=[
|
||||
"{{query}}",
|
||||
{"token": "<reserved_107>"} # assistant token
|
||||
{"token": "<reserved_107>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<reserved_106>" # user token
|
||||
]
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
@@ -568,7 +518,6 @@ register_template(
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
{"token": "<|end|>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
@@ -579,11 +528,13 @@ register_template(
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|end|>"
|
||||
]
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
@@ -594,8 +545,7 @@ register_template(
|
||||
name="chatml",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\n{{system}}",
|
||||
{"token": "<|im_end|>"}
|
||||
"system\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
@@ -607,11 +557,13 @@ register_template(
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
{"token": "<|im_end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
]
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user