add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -100,7 +100,7 @@ def compute_device_flops() -> float:
raise NotImplementedError("Device not supported: {}.".format(device_name))
def compute_mfu(
def calculate_mfu(
model_name_or_path: str,
batch_size: int,
seq_length: int,
@@ -111,7 +111,7 @@ def compute_mfu(
liger_kernel: bool = False,
) -> float:
r"""
Computes MFU for given model and hyper-params.
Calculates MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
@@ -146,4 +146,4 @@ def compute_mfu(
if __name__ == "__main__":
fire.Fire(compute_mfu)
fire.Fire(calculate_mfu)