mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
contiguous views and single HtoD transfer for inputs/targets much cleaner
This commit is contained in:
@@ -154,6 +154,16 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
|
||||
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
|
||||
# This gives us contiguous views and a single HtoD transfer
|
||||
use_cuda = device == "cuda"
|
||||
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
|
||||
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
|
||||
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
|
||||
cpu_targets = cpu_buffer[B * T:].view(B, T)
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(B):
|
||||
@@ -185,13 +195,16 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
use_cuda = device == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
|
||||
# Convert rows to tensor and copy slices to pinned buffer (CPU work)
|
||||
row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary
|
||||
cpu_inputs.copy_(row_data[:, :-1])
|
||||
cpu_targets.copy_(row_data[:, 1:])
|
||||
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
# Single HtoD copy into persistent GPU buffer and yield
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
|
||||
yield inputs, targets, state_dict
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
|
||||
Reference in New Issue
Block a user