fix tests

Former-commit-id: 23f97bd437424ef43b2b84743d56acc5d1ca70d5
This commit is contained in:
hiyouga
2024-01-20 19:58:04 +08:00
parent 80637fc06d
commit 1750218057
12 changed files with 80 additions and 65 deletions

View File

@@ -11,9 +11,10 @@ from typing import Optional
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.model import get_train_args, load_model_and_tokenizer
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@@ -26,7 +27,7 @@ def calculate_lr(
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate
dataset_dir: Optional[str] = "../data"
dataset_dir: Optional[str] = "data"
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
@@ -37,9 +38,8 @@ def calculate_lr(
cutoff_len=cutoff_len,
output_dir="dummy_dir"
))
trainset = get_dataset(model_args, data_args)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True