support FlashAttention2

Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
This commit is contained in:
hiyouga
2023-09-10 20:43:56 +08:00
parent b481ad58e6
commit a402161631
9 changed files with 875 additions and 115 deletions

View File

@@ -206,9 +206,6 @@ def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
@@ -217,6 +214,11 @@ def get_template_and_fix_tokenizer(
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None:
return None
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False