add cal_lr.py

Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
This commit is contained in:
hiyouga
2023-11-14 20:58:37 +08:00
parent c9a4551012
commit 75dd1f0f7e
4 changed files with 67 additions and 6 deletions

View File

@@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
def calculate(
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
@@ -41,4 +41,4 @@ def calculate(
if __name__ == "__main__":
fire.Fire(calculate)
fire.Fire(calculate_flops)