remove unneeded iter()

This commit is contained in:
Eric Silberstein
2025-11-20 15:14:56 -05:00
parent 4a87a0d19f
commit f37d45c21f

View File

@@ -167,17 +167,16 @@ def get_lr_multiplier(it):
# Go!
step = 0
-train_iter = iter(train_loader)
for step in range(num_iterations):
last_step = step == num_iterations - 1
# evaluate the validation loss
if last_step or step % eval_every == 0:
model.eval()
-val_iter = iter(build_val_loader())
+val_loader = build_val_loader()
losses = []
for _ in range(eval_steps):
-val_inputs, val_targets = next(val_iter)
+val_inputs, val_targets = next(val_loader)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
losses.append(loss)
@@ -214,7 +213,7 @@ for step in range(num_iterations):
# evaluate the gradient
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
for micro_step in range(grad_accum_steps):
-train_inputs, train_targets = next(train_iter)
+train_inputs, train_targets = next(train_loader)
with autocast_ctx:
loss = model(train_inputs, train_targets)
train_loss = loss.detach() # for logging