add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -100,7 +100,7 @@ def compute_device_flops() -> float:
|
||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
||||
|
||||
|
||||
def compute_mfu(
|
||||
def calculate_mfu(
|
||||
model_name_or_path: str,
|
||||
batch_size: int,
|
||||
seq_length: int,
|
||||
@@ -111,7 +111,7 @@ def compute_mfu(
|
||||
liger_kernel: bool = False,
|
||||
) -> float:
|
||||
r"""
|
||||
Computes MFU for given model and hyper-params.
|
||||
Calculates MFU for given model and hyper-params.
|
||||
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
|
||||
"""
|
||||
args = {
|
||||
@@ -146,4 +146,4 @@ def compute_mfu(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(compute_mfu)
|
||||
fire.Fire(calculate_mfu)
|
||||
|
||||
Reference in New Issue
Block a user