Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-30.
clean up train/val loader in sft for consistency with mid/base
@@ -167,17 +167,16 @@ def get_lr_multiplier(it):
 
 # Go!
 step = 0
-train_iter = iter(train_loader)
 for step in range(num_iterations):
     last_step = step == num_iterations - 1
 
     # evaluate the validation loss
     if last_step or step % eval_every == 0:
         model.eval()
-        val_iter = iter(build_val_loader())
+        val_loader = build_val_loader()
         losses = []
         for _ in range(eval_steps):
-            val_inputs, val_targets = next(val_iter)
+            val_inputs, val_targets = next(val_loader)
             with torch.no_grad(), autocast_ctx:
                 loss = model(val_inputs, val_targets)
             losses.append(loss)

@@ -214,7 +213,7 @@ for step in range(num_iterations):
     # evaluate the gradient
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
-        train_inputs, train_targets = next(train_iter)
+        train_inputs, train_targets = next(train_loader)
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
             train_loss = loss.detach() # for logging
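For context, a minimal sketch of why dropping the iter() wrappers is safe, assuming the nanochat loaders are Python generators, as the next(...) call pattern in the mid/base scripts suggests; the loader body, batch shapes, and vocabulary size below are made up for illustration, only build_val_loader and the next(...) usage come from the diff:

import torch

def build_val_loader():
    # Hypothetical stand-in for nanochat's loader: a generator that yields
    # (inputs, targets) batches indefinitely. A generator is already an
    # iterator, so iter(build_val_loader()) returns the generator unchanged
    # and the old wrapper was a no-op.
    while True:
        inputs = torch.randint(0, 50304, (4, 128))   # (batch, seq_len), made-up sizes
        targets = torch.randint(0, 50304, (4, 128))
        yield inputs, targets

val_loader = build_val_loader()             # no iter(...) needed
val_inputs, val_targets = next(val_loader)  # same call pattern as the diff

Since iter() on a generator is the identity, the rewrite changes no behavior; it only makes sft consume its loaders the same way mid/base already do.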
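The num_tokens counter in the second hunk tracks how many "active" (supervised) tokens each accumulation window covers. A hedged sketch of how such a counter is typically accumulated, assuming unsupervised positions (e.g. prompt or padding tokens) carry a -1 ignore index in the targets; that sentinel and all sizes here are assumptions, only the counter's initialization appears in the diff:

import torch

IGNORE_INDEX = -1  # assumed sentinel for positions that receive no loss
device = "cuda" if torch.cuda.is_available() else "cpu"

num_tokens = torch.tensor(0, device=device)
grad_accum_steps = 4  # made-up value for the sketch
for micro_step in range(grad_accum_steps):
    # Stand-in batch: the first 32 positions (say, the prompt) are masked out.
    train_targets = torch.randint(0, 50304, (4, 128), device=device)
    train_targets[:, :32] = IGNORE_INDEX
    # Count only the targets the loss is actually computed over.
    num_tokens += (train_targets != IGNORE_INDEX).sum()

print(int(num_tokens))  # supervised tokens seen across the accumulation window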