fix scripts

Former-commit-id: f94f55d20283298cb7d90d0573992a62df414a8f
This commit is contained in:
hiyouga
2024-12-05 03:47:28 +00:00
parent ff3e40e4a5
commit 86e4fab0d5
4 changed files with 32 additions and 24 deletions

View File

@@ -22,9 +22,9 @@ import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from transformers import DataCollatorForLanguageModeling
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -71,7 +71,7 @@ def calculate_lr(
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else:
raise NotImplementedError(f"Stage does not supported: {stage}.")
@@ -81,14 +81,13 @@ def calculate_lr(
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
token_batch_size = cutoff_len * batch_size * valid_ratio
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
lr, valid_ratio * 100, token_batch_size
)
)