disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
This commit is contained in:
@@ -7,12 +7,13 @@ import fire
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.data import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.model import get_train_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
||||
@@ -22,14 +23,16 @@ BASE_BS = 4_000_000 # from llama paper
|
||||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
dataset: str,
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool # mistral model uses a smaller learning rate
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||
dataset_dir: Optional[str] = "data"
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template="default",
|
||||
cutoff_len=cutoff_len,
|
||||
output_dir="dummy_dir"
|
||||
|
||||
Reference in New Issue
Block a user