add cal_lr.py
Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
This commit is contained in:
@@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
from llmtuner import ChatModel
|
||||
|
||||
|
||||
def calculate(
|
||||
def calculate_flops(
|
||||
model_name_or_path: str,
|
||||
batch_size: Optional[int] = 1,
|
||||
seq_length: Optional[int] = 256,
|
||||
@@ -41,4 +41,4 @@ def calculate(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(calculate)
|
||||
fire.Fire(calculate_flops)
|
||||
|
||||
Reference in New Issue
Block a user