fix tf32 warning for deprecated api use

This commit is contained in:
Andrej Karpathy
2025-12-27 22:03:06 +00:00
parent bc51da8bac
commit 49389ecaa8

View File

@@ -158,7 +158,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()