support FlashAttention2
Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
@@ -206,9 +206,6 @@ def get_template_and_fix_tokenizer(
     name: str,
     tokenizer: "PreTrainedTokenizer"
 ) -> Template:
-    template = templates.get(name, None)
-    assert template is not None, "Template {} does not exist.".format(name)
-
     if tokenizer.eos_token_id is None:
         tokenizer.eos_token = "<|endoftext|>"
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
@@ -217,6 +214,11 @@ def get_template_and_fix_tokenizer(
         tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
 
+    if name is None:
+        return None
+
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
     tokenizer.add_special_tokens(
         dict(additional_special_tokens=template.stop_words),
         replace_additional_special_tokens=False
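
A minimal usage sketch of the new name=None path, assuming only what the hunks above show (the checkpoint name and call site below are hypothetical, and the non-None return of the Template object is inferred from the "-> Template" annotation):

    from transformers import AutoTokenizer

    # Hypothetical call site: preparing a tokenizer for pretraining, where no
    # chat template is needed but eos/pad tokens must still exist.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical checkpoint without a pad token

    # After this change, passing name=None only patches missing special tokens
    # (eos/pad) and returns None instead of failing the template lookup.
    template = get_template_and_fix_tokenizer(None, tokenizer)
    assert template is None
    assert tokenizer.pad_token == tokenizer.eos_token

    # Passing a registered template name still asserts the template exists,
    # registers its stop words as additional special tokens, and returns it.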