From f37d45c21fded8660812c7c38a9a13275f123fbb Mon Sep 17 00:00:00 2001
From: Eric Silberstein
Date: Thu, 20 Nov 2025 15:14:56 -0500
Subject: [PATCH] remove unneeded iter()

---
 scripts/chat_sft.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index bbeb1f9..f93a6e6 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -167,17 +167,16 @@ def get_lr_multiplier(it):
 
 # Go!
 step = 0
-train_iter = iter(train_loader)
 for step in range(num_iterations):
     last_step = step == num_iterations - 1
 
     # evaluate the validation loss
     if last_step or step % eval_every == 0:
         model.eval()
-        val_iter = iter(build_val_loader())
+        val_loader = build_val_loader()
         losses = []
         for _ in range(eval_steps):
-            val_inputs, val_targets = next(val_iter)
+            val_inputs, val_targets = next(val_loader)
             with torch.no_grad(), autocast_ctx:
                 loss = model(val_inputs, val_targets)
             losses.append(loss)
@@ -214,7 +213,7 @@ for step in range(num_iterations):
     # evaluate the gradient
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_iter)
+        train_inputs, train_targets = next(train_loader)
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
         train_loss = loss.detach() # for logging
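
Why the iter() calls were unneeded: after this patch, next() is called directly
on train_loader and on the value returned by build_val_loader(), which is only
valid if those objects are already iterators, presumably generators. A minimal
sketch of the Python behavior this relies on; build_loader below is
illustrative and not taken from chat_sft.py:

    def build_loader():
        # A generator function returns a generator, and a generator is its
        # own iterator: iter(gen) just returns gen, so the iter() wrapper
        # removed by this patch was a no-op.
        while True:
            yield "inputs", "targets"

    loader = build_loader()
    assert iter(loader) is loader   # iter() on a generator is the identity
    inputs, targets = next(loader)  # next() works directly, no iter() needed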