Update length_cdf to build the template with get_template_and_fix_tokenizer and pass it to get_dataset
Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
This commit is contained in:
@@ -18,7 +18,7 @@ from collections import defaultdict
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from llamafactory.data import get_dataset
|
||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@@ -48,7 +48,8 @@ def length_cdf(
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
|
||||
total_num = len(trainset)
|
||||
length_dict = defaultdict(int)
|
||||
for sample in tqdm(trainset["input_ids"]):
|
||||
|
||||
Reference in New Issue
Block a user