[misc] fix grad ckpt func (#6916)
Former-commit-id: 35e069a52b3d7cfd9b0107574b09265eb2290f0b
This commit is contained in:
@@ -142,21 +142,23 @@ def calculate_mfu(
|
||||
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
|
||||
|
||||
run_exp(args)
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||
result = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
total_batch_size = batch_size * world_size
|
||||
mfu_value = (
|
||||
result["train_steps_per_second"]
|
||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||
/ compute_device_flops(world_size)
|
||||
)
|
||||
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||
result = json.load(f)
|
||||
|
||||
total_batch_size = batch_size * world_size
|
||||
mfu_value = (
|
||||
result["train_steps_per_second"]
|
||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||
/ compute_device_flops(world_size)
|
||||
)
|
||||
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user