support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
This commit is contained in:
@@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
|
||||
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
@@ -71,7 +71,9 @@ def calculate_lr(
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user