[misc] fix grad ckpt func (#6916)

Former-commit-id: 35e069a52b3d7cfd9b0107574b09265eb2290f0b
This commit is contained in:
hoshi-hiyouga
2025-02-13 00:17:18 +08:00
committed by GitHub
parent 0c0cdc26bc
commit 3a3f4072e5
3 changed files with 17 additions and 13 deletions

View File

@@ -142,21 +142,23 @@ def calculate_mfu(
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
dist.barrier()
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if int(os.getenv("LOCAL_RANK", "0")) == 0:
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":