fix scripts
Former-commit-id: f94f55d20283298cb7d90d0573992a62df414a8f
@@ -22,9 +22,9 @@ import fire
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling
 
-from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -71,7 +71,7 @@ def calculate_lr(
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
-        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+        data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
     else:
         raise NotImplementedError(f"Stage does not supported: {stage}.")
 
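Note on the change above: the "sft" branch now builds LLaMA-Factory's own MultiModalDataCollatorForSeq2Seq instead of the stock transformers DataCollatorForSeq2Seq, passing the chat template alongside the tokenizer so that multimodal samples are collated and padded with the same rules as text. Below is a minimal sketch of how the new collator could be wired into the token-counting DataLoader; the surrounding names (trainset, template, tokenizer, batch_size) are assumed from the rest of the script and this is not an exact excerpt of it.

# Sketch, not an excerpt: build the project-level collator and feed it to a DataLoader.
# `trainset`, `template`, `tokenizer`, and `batch_size` are assumed to come from the
# get_dataset / get_template_and_fix_tokenizer / load_tokenizer helpers imported above.
data_collator = MultiModalDataCollatorForSeq2Seq(
    template=template,                # lets multimodal placeholder tokens be padded like text
    tokenizer=tokenizer,
    label_pad_token_id=IGNORE_INDEX,  # padded label positions are excluded when counting valid tokens
)
dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)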
@@ -81,14 +81,13 @@ def calculate_lr(
         valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
         total_tokens += torch.numel(batch["labels"])
 
-    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
     valid_ratio = valid_tokens / total_tokens
-    batch_valid_len = batch_max_len * valid_ratio
-    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
+    token_batch_size = cutoff_len * batch_size * valid_ratio
+    lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
-        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
-            lr, valid_ratio * 100, batch_valid_len
+        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
+            lr, valid_ratio * 100, token_batch_size
         )
     )
 
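Note on the change above: the rewrite folds batch_max_len and batch_valid_len into a single token_batch_size = cutoff_len * batch_size * valid_ratio, which is numerically the same quantity, the expected number of loss-bearing tokens per optimizer step. The learning rate is then scaled by the square root of that token count relative to a reference batch size. A small, self-contained worked example of the rule follows; BASE_LR, BASE_BS, cutoff_len, batch_size, and valid_ratio are illustrative assumptions here, not values taken from this commit.

# Worked example of the square-root scaling rule used above (illustrative values only).
import math

BASE_LR = 3e-4        # assumed reference learning rate at the reference batch size
BASE_BS = 4_000_000   # assumed reference batch size, measured in tokens

cutoff_len, batch_size, valid_ratio = 2048, 512, 0.6
token_batch_size = cutoff_len * batch_size * valid_ratio   # ~629k supervised tokens per step
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)       # lr ~ sqrt(batch_size)
print(f"{lr:.2e}")    # ~1.19e-04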