upgrading all other files to be able to use cpu/mps as well as cuda. various other minor changes, e.g. renaming max_iterations to num_iterations in the sft script for naming consistency

karpathy
2025-10-20 10:15:17 -07:00
parent a09ac812ed
commit 2e9669e03a
9 changed files with 298 additions and 90 deletions

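For context, a minimal sketch of the kind of cuda/cpu/mps autodetection this commit rolls out across the scripts. This is not the verbatim code from the commit; the helper name is hypothetical, but the availability checks are standard PyTorch, and the empty-string default mirrors the `device_type = ""` config visible in the diff below.

import torch

def autodetect_device_type() -> str:
    if torch.cuda.is_available():           # prefer NVIDIA GPUs when present
        return "cuda"
    if torch.backends.mps.is_available():   # Apple Silicon GPU backend
        return "mps"
    return "cpu"                            # fallback that always works

device_type = ""  # cuda|cpu|mps (empty => autodetect)
device_type = device_type or autodetect_device_type()
device = torch.device(device_type)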

@@ -34,6 +34,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect)
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
+num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 max_seq_len = 2048
 device_batch_size = 32
 unembedding_lr = 0.004
@@ -116,6 +117,7 @@ def mid_data_generator(split):
     token_buffer = deque()
     scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
+    it = 0 # iteration counter
     while True:
         # Accumulate enough tokens for one iteration before yielding
         while len(token_buffer) < needed_tokens:
@@ -127,6 +129,10 @@ def mid_data_generator(split):
                 cursor -= dataset_size # wrap around for another epoch
                 if split == "train":
                     last_step = True # toggle last_step to True, which will terminate the training loop
+        # Stopping condition to respect num_iterations, if given
+        it += 1
+        if num_iterations > 0 and it >= num_iterations:
+            last_step = True # toggle last_step to True, which will terminate the training loop
         # Build up inputs/targets and yield
         for i in range(needed_tokens):
             scratch[i] = token_buffer.popleft()
@@ -135,7 +141,10 @@
         inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
         targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
         if split == "train":
-            approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
+            if num_iterations > 0:
+                approx_progress = it / num_iterations # calculate progress from the max number of iterations
+            else:
+                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
         yield inputs, targets
 
 train_loader = mid_data_generator("train")
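For reference, a hypothetical standalone sketch of the stopping-and-progress rule the hunks above introduce (all names are illustrative, not from the repo): with num_iterations > 0 the loop terminates after exactly that many iterations and reports progress as it / num_iterations; with the -1 default it falls back to the cursor's fraction of the dataset.

def bounded_progress(num_iterations: int, dataset_size: int):
    it, cursor, last_step = 0, 0, False
    while not last_step:
        it += 1
        cursor = (cursor + 1) % dataset_size  # stand-in for document consumption
        if num_iterations > 0 and it >= num_iterations:
            last_step = True  # terminates the loop, like the training-loop toggle
        if num_iterations > 0:
            progress = it / num_iterations
        else:
            progress = cursor / dataset_size
        yield it, progress

print(list(bounded_progress(num_iterations=4, dataset_size=100)))
# [(1, 0.25), (2, 0.5), (3, 0.75), (4, 1.0)]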