[model] support yarn (#6693)

Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
This commit is contained in:
hoshi-hiyouga
2025-01-18 13:56:09 +08:00
committed by GitHub
parent e4046bdd1f
commit 87d685b59f
11 changed files with 84 additions and 64 deletions

View File

@@ -33,7 +33,7 @@ def preprocess_pretrain_dataset(
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
if getattr(tokenizer, "add_bos_token", False):
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
@@ -47,7 +47,7 @@ def preprocess_pretrain_dataset(
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
if getattr(tokenizer, "add_bos_token", False):
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id