mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
remove unneeded iter()
This commit is contained in:
@@ -167,17 +167,16 @@ def get_lr_multiplier(it):
|
|||||||
|
|
||||||
# Go!
|
# Go!
|
||||||
step = 0
|
step = 0
|
||||||
train_iter = iter(train_loader)
|
|
||||||
for step in range(num_iterations):
|
for step in range(num_iterations):
|
||||||
last_step = step == num_iterations - 1
|
last_step = step == num_iterations - 1
|
||||||
|
|
||||||
# evaluate the validation loss
|
# evaluate the validation loss
|
||||||
if last_step or step % eval_every == 0:
|
if last_step or step % eval_every == 0:
|
||||||
model.eval()
|
model.eval()
|
||||||
val_iter = iter(build_val_loader())
|
val_loader = build_val_loader()
|
||||||
losses = []
|
losses = []
|
||||||
for _ in range(eval_steps):
|
for _ in range(eval_steps):
|
||||||
val_inputs, val_targets = next(val_iter)
|
val_inputs, val_targets = next(val_loader)
|
||||||
with torch.no_grad(), autocast_ctx:
|
with torch.no_grad(), autocast_ctx:
|
||||||
loss = model(val_inputs, val_targets)
|
loss = model(val_inputs, val_targets)
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
@@ -214,7 +213,7 @@ for step in range(num_iterations):
|
|||||||
# evaluate the gradient
|
# evaluate the gradient
|
||||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||||
for micro_step in range(grad_accum_steps):
|
for micro_step in range(grad_accum_steps):
|
||||||
train_inputs, train_targets = next(train_iter)
|
train_inputs, train_targets = next(train_loader)
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
loss = model(train_inputs, train_targets)
|
loss = model(train_inputs, train_targets)
|
||||||
train_loss = loss.detach() # for logging
|
train_loss = loss.detach() # for logging
|
||||||
|
|||||||
Reference in New Issue
Block a user