update scripts
Former-commit-id: 1c07648c4bb4bb0c46bc0240547b46bd2835dce1
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -24,7 +25,7 @@ BASE_BS = 4_000_000 # from llama paper
|
||||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
stage: str = "sft",
|
||||
stage: Literal["pt", "sft"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
|
||||
Reference in New Issue
Block a user