fix resize vocab at inference #3022

Former-commit-id: c243720b89eec0af2872fa3c7980a0026d893f4d
hiyouga
2024-04-03 18:14:24 +08:00
parent f6530222f7
commit 1348f7d860
9 changed files with 31 additions and 40 deletions
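For context, the sketch below illustrates what "resize vocab at inference" generally involves, using the standard transformers API rather than this repository's exact code; the model path and added token are placeholders. When extra tokens were added to the tokenizer during training, the embedding table must also be enlarged when the model is loaded again for inference, otherwise the new token ids index past the end of the matrix.

# Illustrative sketch only: generic transformers pattern, not LLaMA-Factory's exact code.
# "path/to/model" and the <pad> token are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model")
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # tokenizer vocab grows by one

model = AutoModelForCausalLM.from_pretrained("path/to/model")
# Resize the embedding matrix so inference with the enlarged tokenizer
# cannot index past the end of the embedding table.
if len(tokenizer) > model.get_input_embeddings().weight.size(0):
    model.resize_token_embeddings(len(tokenizer))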


@@ -15,7 +15,7 @@ from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 from llmtuner.data import get_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.hparams import get_train_args
-from llmtuner.model import load_model_and_tokenizer
+from llmtuner.model import load_tokenizer
 
 
 BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
@@ -32,7 +32,7 @@ def calculate_lr(
     cutoff_len: Optional[int] = 1024,  # i.e. maximum input length during training
     is_mistral: Optional[bool] = False,  # mistral model uses a smaller learning rate,
 ):
-    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
+    model_args, data_args, training_args, _, _ = get_train_args(
         dict(
             stage=stage,
             model_name_or_path=model_name_or_path,
@@ -44,8 +44,8 @@ def calculate_lr(
             overwrite_cache=True,
         )
     )
-    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
+    tokenizer = load_tokenizer(model_args)
+    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":