fix silly issue in dataloader, this version is much faster and more portable to mps too

Andrej Karpathy
2025-10-21 17:12:50 +00:00
parent c9ea7a91e2
commit bb71c64579


@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
 
     # infinite iterator over document batches
     def document_batches():
@@ -38,8 +37,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
                 token_buffer.extend(tokens)
             batch_index += 1
         # Move tokens from the deque into the scratch buffer
-        for i in range(needed_tokens):
-            scratch[i] = token_buffer.popleft()
+        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
+        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
         # Create the inputs/targets as 1D tensors
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]