diff --git a/nanochat/common.py b/nanochat/common.py
index faf9144..44760f9 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -201,51 +201,76 @@ class DummyWandb:
     def finish(self):
         pass
 
-# hardcoded BF16 peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC
+# hardcoded BF16 peak flops for various GPUs
 # inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
+# and PR: https://github.com/karpathy/nanochat/pull/147
 def get_peak_flops(device_name: str) -> float:
-    if "A100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/a100/
-        return 312e12
-    elif "H100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/h100/
-        # NOTE: Specifications are one-half lower without sparsity.
-        if "NVL" in device_name:
-            return 835e12
-        elif "PCIe" in device_name:
-            return 756e12
-        else: # for H100 SXM and other variants
-            return 989e12
-    elif "H200" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/h200/
-        return 989e12
-    elif "B200" in device_name:
-        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
+    name = device_name.lower()
+
+    # --- NVIDIA Blackwell ---
+    if "gb200" in name or "grace blackwell" in name:
+        return 2.5e15
+    if "b200" in name:
         return 2.25e15
-    elif "MI355X" in device_name:
-        # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
-        return 2500e12
-    elif "MI300X" in device_name or "MI325X" in device_name:
-        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
-        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
-        return 1300e12
-    elif "MI250X" in device_name:
-        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
-        return 191.5e12
-    elif "Data Center GPU Max 1550" in device_name:
-        # Also known as Ponte Vecchio (PVC).
-        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
-        # Dot Product Accumulate Systolic (DPAS):
-        # - Freq: 1300MHz
-        # - #ops: 512
-        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
-        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
+    if "b100" in name:
+        return 1.8e15
+
+    # --- NVIDIA Hopper (H100/H200/H800) ---
+    if "h200" in name:
+        if "nvl" in name or "pcie" in name:
+            return 836e12
+        return 989e12 # H200 SXM
+    if "h100" in name:
+        if "nvl" in name:
+            return 835e12
+        if "pcie" in name:
+            return 756e12
+        return 989e12 # H100 SXM
+    if "h800" in name:
+        if "nvl" in name:
+            return 989e12
+        return 756e12 # H800 PCIe
+
+    # --- NVIDIA Ampere data center ---
+    if "a100" in name or "a800" in name:
+        return 312e12
+    if "a40" in name:
+        return 149.7e12
+    if "a30" in name:
+        return 165e12
+
+    # --- NVIDIA Ada data center ---
+    if "l40s" in name or "l40-s" in name or "l40 s" in name:
+        return 362e12
+    if "l4" in name:
+        return 121e12
+
+    # --- AMD CDNA accelerators ---
+    if "mi355" in name:
+        return 2.5e15
+    if "mi325" in name or "mi300x" in name:
+        return 1.3074e15
+    if "mi300a" in name:
+        return 980.6e12
+    if "mi250x" in name:
+        return 383e12
+    if "mi250" in name:
+        return 362.1e12
+
+    # --- Intel ---
+    if "data center gpu max 1550" in name:
+        # Ponte Vecchio (PVC) - dynamic based on compute units
         max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
         return 512 * max_comp_units * 1300 * 10**6
-    elif "l40s" in device_name:
-        # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
-        return 362e12
-    else: # for other GPU types, assume A100
-        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
-        return 312e12
+    # --- Consumer RTX (for hobbyists) ---
+    if "5090" in name:
+        return 209.5e12
+    if "4090" in name:
+        return 165.2e12
+    if "3090" in name:
+        return 71e12
+
+    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
+    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
+    return float('inf')
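
For context on how this value feeds model FLOPs utilization (MFU): MFU is achieved FLOP/s divided by the peak returned here, so returning `float('inf')` for unknown GPUs drives the reported MFU to 0% instead of basing it on a wrong guess. Below is a minimal sketch of that calculation, assuming the module path shown in this diff; the `estimate_mfu` helper and its example numbers are illustrative, not part of nanochat's API.

```python
import torch

from nanochat.common import get_peak_flops  # module path taken from this diff


def estimate_mfu(flops_per_iter: float, iter_time_s: float, world_size: int = 1) -> float:
    """Illustrative MFU: achieved FLOP/s divided by aggregate peak FLOP/s (hypothetical helper)."""
    device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"
    peak = get_peak_flops(device_name) * world_size
    achieved = flops_per_iter / iter_time_s
    return achieved / peak  # dividing by float('inf') yields 0.0, i.e. MFU reported as 0%


# e.g. ~4.4e15 FLOPs per optimizer step taking 1.2 s on one GPU
print(f"MFU: {estimate_mfu(4.4e15, 1.2):.1%}")
```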