update scripts

Former-commit-id: 1c07648c4bb4bb0c46bc0240547b46bd2835dce1
This commit is contained in:
hiyouga
2024-05-04 23:05:17 +08:00
parent 9b187b274c
commit f9aa74715a
2 changed files with 35 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
from typing import Literal
import fire
import torch
@@ -24,7 +25,7 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: str = "sft",
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",