Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
This commit is contained in:
hiyouga
2024-09-08 23:18:08 +08:00
parent 3259397f89
commit 3cbc9109ea
2 changed files with 59 additions and 47 deletions

View File

@@ -18,6 +18,7 @@ import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
batch_size: int,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = batch_size * seq_length * embedding_flops_per_token
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
@@ -85,17 +86,19 @@ def compute_model_flops(
return total_flops
def compute_device_flops() -> float:
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
device_name = torch.cuda.get_device_name()
device_count = torch.cuda.device_count()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * device_count
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * device_count
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * device_count
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * device_count
return 98 * 1e12 * world_size
else:
raise NotImplementedError("Device not supported: {}.".format(device_name))
@@ -140,10 +143,16 @@ def calculate_mfu(
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, batch_size, seq_length)
/ compute_device_flops()
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print("MFU: {:.2f}%".format(mfu_value * 100))