saving these two attempts

Andrej Karpathy
2026-01-13 17:05:09 +00:00
parent 4610a838a1
commit a6382a6ce8
2 changed files with 426 additions and 0 deletions

nanochat/fp8_dynamic.py (new file, 193 lines)

@@ -0,0 +1,193 @@
"""
Linear layer with FP8 training.
Uses dynamic scaling for activations, weights, and gradients (all computed each forward).
Implementation pattern inspired by torchao.float8:
- Uses @torch._dynamo.allow_in_graph on autograd.Function for torch.compile compatibility
- Scales stay as tensors throughout (no .item() calls)
- torchao additionally wraps its FP8 matmul kernels in an inner @torch.compile; that is
  omitted here (see the note above the kernel functions below)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# FP8 format constants
FP8_E4M3_MAX = 448.0 # max representable value in float8_e4m3fn
FP8_E5M2_MAX = 57344.0 # max representable value in float8_e5m2
EPS = 1e-12 # epsilon for numerical stability in scale computation
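# (These match torch's dtype metadata: torch.finfo(torch.float8_e4m3fn).max == 448.0
#  and torch.finfo(torch.float8_e5m2).max == 57344.0.)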
# -----------------------------------------------------------------------------
# FP8 kernel functions
# These run the FP8 matmul on H100 tensor cores.
# Note: No @torch.compile here - these are called from within a @torch._dynamo.allow_in_graph
# function, which is already inside the outer torch.compile. Nested compile causes issues.
def _fp8_forward_impl(x: torch.Tensor, w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor):
"""FP8 forward: out = x @ w.T with FP8 quantization."""
x_f8 = (x / x_scale).to(torch.float8_e4m3fn)
w_f8 = (w / w_scale).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x_scale,
scale_b=w_scale,
use_fast_accum=True,
)
return out, x_f8, w_f8
def _fp8_backward_impl(
grad: torch.Tensor,
x_f8: torch.Tensor,
w_f8: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
grad_scale: torch.Tensor,
):
"""FP8 backward: compute gradients for x and w."""
grad = grad.contiguous()
grad_f8 = (grad / grad_scale).to(torch.float8_e5m2)
# grad_x = grad @ W
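    # (w_f8.T.contiguous().T materializes the transpose and views it back, yielding a
    #  column-major tensor - the layout torch._scaled_mm expects for its second operand.)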
grad_x = torch._scaled_mm(
grad_f8,
w_f8.T.contiguous().T,
out_dtype=torch.bfloat16,
scale_a=grad_scale,
scale_b=w_scale,
use_fast_accum=False,
)
# grad_w = x.T @ grad (output in float32 for optimizer stability)
grad_w = torch._scaled_mm(
x_f8.T.contiguous(),
grad_f8.T.contiguous().T,
out_dtype=torch.float32,
scale_a=x_scale,
scale_b=grad_scale,
use_fast_accum=False,
).T
return grad_x, grad_w
# -----------------------------------------------------------------------------
# Autograd function with @torch._dynamo.allow_in_graph
# This pattern allows the function to be included in torch.compile graphs.
@torch._dynamo.allow_in_graph
class FP8Matmul(torch.autograd.Function):
"""
FP8 matrix multiply: out = x @ w.T with dynamic scaling.
This autograd.Function is decorated with @torch._dynamo.allow_in_graph,
which tells torch.compile to include it in the compiled graph without
    attempting to trace through it (avoiding graph breaks and other tracing issues).
"""
@staticmethod
def forward(ctx, x: torch.Tensor, w: torch.Tensor, grad_scale: torch.Tensor):
"""
Args:
x: Input activations (2D, contiguous) in bfloat16
w: Weight matrix (2D, contiguous) in bfloat16
grad_scale: Pre-computed scale for gradients in backward
Returns:
out: Result in bfloat16
"""
# Compute scales dynamically as tensors (no .item()!)
x_amax = x.abs().max()
w_amax = w.abs().max()
x_scale = (x_amax / FP8_E4M3_MAX).clamp(min=EPS).float()
w_scale = (w_amax / FP8_E4M3_MAX).clamp(min=EPS).float()
# Run FP8 forward
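        # (_scaled_mm multiplies the FP8 product by scale_a * scale_b, undoing the division
        #  by x_scale and w_scale inside _fp8_forward_impl, so out ~= x @ w.T in bfloat16)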
out, x_f8, w_f8 = _fp8_forward_impl(x, w, x_scale, w_scale)
# Save for backward
ctx.save_for_backward(x_f8, w_f8, x_scale, w_scale, grad_scale)
return out
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
x_f8, w_f8, x_scale, w_scale, grad_scale = ctx.saved_tensors
# Run FP8 backward
grad_x, grad_w = _fp8_backward_impl(grad_out, x_f8, w_f8, x_scale, w_scale, grad_scale)
return grad_x, grad_w, None # None for grad_scale
# -----------------------------------------------------------------------------
class LinearFP8(nn.Linear):
"""
Linear layer with FP8 training.
During training, uses FP8 matmul with:
- Dynamic scaling for activations (computed each forward)
- Dynamic scaling for weights (computed each forward)
- Dynamic scaling for gradients (computed from input shape)
During inference, uses standard BF16 matmul.
    IMPORTANT: This layer is currently intended solely for use as the lm_head (the final
    projection to the vocabulary). It assumes the input x has shape (B, T, in_features),
    where B*T is the number of tokens entering the cross-entropy loss. grad_scale is
    computed as (1 / (B*T)) / FP8_E5M2_MAX, which assumes cross-entropy with mean
    reduction, where the gradient magnitude is ~1/(B*T).
    Nothing prevents it from being used as a regular layer, except that the grad_scale
    handling would have to be adjusted.
Args:
in_features: Input dimension
out_features: Output dimension
bias: Must be False for now, might support it later
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False):
assert bias is False, "LinearFP8 does not support bias (FP8 matmul has no bias fusion)"
super().__init__(in_features, out_features, bias=False)
# Latest stats (tensors for logging, detached to avoid grad issues)
self._x_amax: torch.Tensor | None = None
self._w_amax: torch.Tensor | None = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
assert x.ndim == 3 and x.shape[2] == self.in_features, f"Expected input shape (B, T, {self.in_features}), got {x.shape}"
B, T, _ = x.shape
# Flatten to 2D for matmul
_x = x.flatten(0, -2) # (B, T, in_features) -> (B*T, in_features)
# Compute grad_scale as a tensor (assumes cross-entropy with mean reduction)
grad_scale = torch.tensor(1.0 / (B * T) / FP8_E5M2_MAX, device=x.device, dtype=torch.float32)
# Run FP8 matmul
out = FP8Matmul.apply(_x, self.weight, grad_scale)
# Reshape back
out = out.reshape(B, T, -1) # (B*T, out_features) -> (B, T, out_features)
# Update stats for logging (detach to avoid keeping grad graph)
self._x_amax = _x.abs().max().detach()
self._w_amax = self.weight.abs().max().detach()
else:
# Standard linear forward (inference)
out = F.linear(x, self.weight.type_as(x))
return out
def get_fp8_stats(self) -> dict:
"""Return the latest FP8 statistics for logging."""
return {
"x_amax": self._x_amax.item() if self._x_amax is not None else None,
"w_amax": self._w_amax.item() if self._w_amax is not None else None,
}
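# -----------------------------------------------------------------------------
# Usage sketch: LinearFP8 standing in as an lm_head feeding cross-entropy, which is the
# setting the grad_scale assumption above relies on. The function and all shapes below are
# hypothetical, and torch._scaled_mm needs an FP8-capable GPU (e.g. H100) to actually run.
def _usage_sketch():
    B, T, C, V = 4, 128, 768, 50304  # made-up batch, sequence, model dim, vocab size
    lm_head = LinearFP8(C, V).cuda()  # weight kept in float32 so grad_w's dtype matches
    x = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    targets = torch.randint(V, (B, T), device="cuda")
    logits = lm_head(x)  # training mode -> FP8 matmul with scales recomputed this forward
    loss = F.cross_entropy(logits.flatten(0, 1).float(), targets.flatten())  # mean reduction
    loss.backward()  # FP8 backward: grad of x in bfloat16, lm_head.weight.grad in float32
    print(f"loss={loss.item():.3f} stats={lm_head.get_fp8_stats()}")
    lm_head.eval()
    with torch.no_grad():
        _ = lm_head(x)  # inference path: plain BF16 F.linear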

nanochat/fp8_static.py (new file, 233 lines)

@@ -0,0 +1,233 @@
"""
Linear layer with FP8 training using static scaling.
Scales for x and w are fixed at init time - no runtime amax computation. The grad scale is
a Python float derived from the input shape on each forward (see LinearFP8 below).
This is the approach used by modded-nanogpt and avoids torch.compile issues.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# FP8 format constants
FP8_E4M3_MAX = 448.0 # max representable value in float8_e4m3fn
FP8_E5M2_MAX = 57344.0 # max representable value in float8_e5m2
# -----------------------------------------------------------------------------
# Custom FP8 matmul operators (based on modded-nanogpt)
# These use torch.library.custom_op for torch.compile compatibility
#
# All scales are Python floats passed at call time (x_s and w_s are fixed at init;
# grad_s is derived from the input shape on each forward).
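#
# torch.compile treats a torch.library custom op as a single opaque call and uses the
# registered fake (meta) implementation for shape/dtype propagation, so the inner
# @torch.compile'd kernels below are never traced by an outer compile of the model.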
@torch.library.custom_op("nanochat::fp8_mm_static", mutates_args=())
def fp8_mm_op(x: torch.Tensor, w: torch.Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""FP8 matrix multiply: out = x @ w.T with FP8 quantization.
Args:
x: Input activations (2D, contiguous)
w: Weight matrix (2D, contiguous)
x_s: Scale for x (x_amax / 448)
w_s: Scale for w (w_amax / 448)
grad_s: Scale for gradients in backward (grad_amax / 57344)
Returns:
out: Result in bfloat16
x_f8: x quantized to FP8 (saved for backward)
w_f8: w quantized to FP8 (saved for backward)
"""
@torch.compile
def impl(x: torch.Tensor, w: torch.Tensor):
assert x.is_contiguous() and w.is_contiguous()
x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
@fp8_mm_op.register_fake
def _(x: torch.Tensor, w: torch.Tensor, x_s: float, w_s: float, grad_s: float):
assert x.ndim == w.ndim == 2
assert x.shape[1] == w.shape[1]
assert x.device == w.device
assert x.is_contiguous() and w.is_contiguous()
return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
@torch.library.custom_op("nanochat::fp8_mm_static_backward", mutates_args=())
def fp8_mm_backward_op(g: torch.Tensor, x_f8: torch.Tensor, w_f8: torch.Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[torch.Tensor, torch.Tensor]:
"""Backward pass for FP8 matmul.
Args:
g: Gradient of output (dL/dy)
x_f8: Saved FP8 activations from forward
w_f8: Saved FP8 weights from forward
x_s, w_s, grad_s: Scale factors
Returns:
grad_x: Gradient w.r.t. input (bfloat16)
grad_w: Gradient w.r.t. weights (float32 for optimizer)
"""
@torch.compile
def impl(grad: torch.Tensor, x_f8: torch.Tensor, w_f8: torch.Tensor):
grad = grad.contiguous()
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
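        # (despite the *_inv_s names carried over from modded-nanogpt, these tensors hold the
        #  scales themselves; _scaled_mm multiplies its FP8 product by scale_a * scale_b)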
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
# grad_x = grad @ W
grad_x = torch._scaled_mm(
grad_f8,
w_f8.T.contiguous().T,
out_dtype=torch.bfloat16,
scale_a=grad_inv_s,
scale_b=w_inv_s,
use_fast_accum=False,
)
# grad_w = x.T @ grad (output in float32 for optimizer stability)
grad_w = torch._scaled_mm(
x_f8.T.contiguous(),
grad_f8.T.contiguous().T,
out_dtype=torch.float32,
scale_a=x_inv_s,
scale_b=grad_inv_s,
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
@fp8_mm_backward_op.register_fake
def _(g: torch.Tensor, x_f8: torch.Tensor, w_f8: torch.Tensor, x_s: float, w_s: float, grad_s: float):
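    # Fake (meta) implementation: only the output shapes, dtypes and strides matter here;
    # the .T.contiguous().T mirrors the column-major layout of the real grad_w output.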
return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
def _fp8_mm_backward(ctx, grad_out: torch.Tensor, *_):
x_f8, w_f8 = ctx.saved_tensors
x_s, w_s, grad_s = ctx.scales
grad_x, grad_w = torch.ops.nanochat.fp8_mm_static_backward(
grad_out, x_f8, w_f8, x_s, w_s, grad_s
)
return grad_x, grad_w, None, None, None
def _fp8_mm_setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
*_, x_s, w_s, grad_s = inputs
_, x_f8, w_f8 = output
ctx.save_for_backward(x_f8, w_f8)
ctx.scales = x_s, w_s, grad_s
ctx.set_materialize_grads(False)
fp8_mm_op.register_autograd(_fp8_mm_backward, setup_context=_fp8_mm_setup_context)
# -----------------------------------------------------------------------------
class LinearFP8(nn.Linear):
"""
Linear layer with FP8 training using static scaling.
Scales for x and w are set at initialization - no runtime amax computation.
Grad scale is computed dynamically from input shape.
IMPORTANT: This layer assumes it is used as an unembedding/classifier layer
where the output goes directly into softmax cross-entropy loss. This assumption
allows us to compute grad_scale dynamically: the gradient of cross-entropy w.r.t.
    logits is (softmax - one_hot), whose entries are bounded by 1 in magnitude. With mean
    reduction over B*T tokens, the gradient amax is therefore at most 1/(B*T).
During training, uses FP8 matmul with static x/w scales and dynamic grad scale.
During inference, uses standard BF16 matmul.
Args:
in_features: Input dimension
out_features: Output dimension (vocabulary size)
bias: Must be False (FP8 matmul has no bias fusion)
x_scale: Scale for activations = expected_x_amax / 448. Required.
w_scale: Scale for weights = expected_w_amax / 448. Required.
monitor: If True, record actual amax values each forward for get_fp8_stats().
Adds small overhead from .item() calls. Default False.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
        x_scale: float | None = None,
        w_scale: float | None = None,
monitor: bool = False,
):
assert bias is False, "LinearFP8 does not support bias (FP8 matmul has no bias fusion)"
assert x_scale is not None, "x_scale is required (expected_x_amax / 448)"
assert w_scale is not None, "w_scale is required (expected_w_amax / 448)"
super().__init__(in_features, out_features, bias=False)
self.x_scale = x_scale
self.w_scale = w_scale
self.monitor = monitor
# Observed amax values (updated each forward when monitor=True)
self._x_amax: float | None = None
self._w_amax: float | None = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
# This layer assumes (B, T, C) input for classifier/lm_head
assert x.ndim == 3, f"Expected input shape (B, T, {self.in_features}), got {x.shape}"
B, T, C = x.shape
assert C == self.in_features, f"Expected C={self.in_features}, got {C}"
# Compute grad_scale dynamically from batch size.
# Assumption: this is an unembedding layer going into cross-entropy loss.
            # Gradient of CE w.r.t. logits = (softmax - one_hot), bounded by 1 in magnitude.
            # With mean reduction over B*T tokens, grad amax is at most 1/(B*T).
grad_amax = 1.0 / (B * T)
grad_scale = grad_amax / FP8_E5M2_MAX
# Flatten to 2D, do the matmul, reshape back
_x = x.flatten(0, -2) # (B, T, C) -> (B*T, C)
out, _, _ = torch.ops.nanochat.fp8_mm_static(_x, self.weight, self.x_scale, self.w_scale, grad_scale)
out = out.reshape(B, T, -1) # (B*T, V) -> (B, T, V)
# Record actual amax for monitoring (detect if values exceed static scale assumptions)
if self.monitor:
self._x_amax = _x.detach().abs().max().item()
self._w_amax = self.weight.detach().abs().max().item()
else:
# Standard linear forward (inference)
out = F.linear(x, self.weight.type_as(x))
return out
def get_fp8_stats(self) -> dict:
"""Return observed amax values for monitoring.
Compare these against the expected amax implied by static scales:
- expected_x_amax = x_scale * 448
- expected_w_amax = w_scale * 448
If observed > expected, values are being clipped and you should increase the scale.
"""
return {
"x_amax": self._x_amax,
"w_amax": self._w_amax,
}
def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, bias=False, "
f"x_scale={self.x_scale:.2e}, w_scale={self.w_scale:.2e}"
)
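# -----------------------------------------------------------------------------
# Sketch: a quick numerical check of the grad_scale reasoning above, plus a hypothetical
# LinearFP8 construction with made-up static scales (in practice the expected amax values
# would be measured on the real model, e.g. with monitor=True). Function and shapes are
# illustrative only.
def _sanity_sketch():
    B, T, C, V = 4, 64, 256, 8192  # made-up shapes
    logits = torch.randn(B * T, V, requires_grad=True)
    targets = torch.randint(V, (B * T,))
    F.cross_entropy(logits, targets).backward()  # mean reduction over B*T tokens
    grad_amax = logits.grad.abs().max().item()
    assert grad_amax <= 1.0 / (B * T) + 1e-9  # matches the 1/(B*T) bound used for grad_scale
    # Static scales = expected amax / FP8 max; the amax guesses (20.0 and 0.5) are arbitrary.
    lm_head = LinearFP8(C, V, x_scale=20.0 / FP8_E4M3_MAX, w_scale=0.5 / FP8_E4M3_MAX, monitor=True)
    print(f"observed grad amax {grad_amax:.2e} <= 1/(B*T) = {1.0 / (B * T):.2e}")
    print(lm_head)  # extra_repr shows the configured scales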