"""
|
|
GPT model (rewrite, a lot simpler)
|
|
Notable features:
|
|
- rotary embeddings (and no positional embeddings)
|
|
- QK norm
|
|
- untied weights for token embedding and lm_head
|
|
- relu^2 activation in MLP
|
|
- norm after token embedding
|
|
- no learnable params in rmsnorm
|
|
- no bias in linear layers
|
|
- Group-Query Attention (GQA) support for more efficient inference
|
|
- Flash Attention 3 integration
|
|
"""
|
|
|
|
from functools import partial
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW

# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn

@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768
    # Sliding window attention pattern string, tiled across layers. Final layer always L.
    # Characters: L=long (full context), S=short (half context)
    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
    window_pattern: str = "L"
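    # Illustrative example (not a default): with n_layer=6 and window_pattern="SSL",
    # tiling gives layers S,S,L,S,S,L; the final layer is always forced to L by
    # _compute_window_sizes() below, whatever the pattern says.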


def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))


def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)
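
# Shape sketch for apply_rotary_emb (illustrative values, mirroring how it is called below):
# x is (B, T, n_heads, head_dim) and cos/sin come from _precompute_rotary_embeddings as
# (1, T, 1, head_dim//2), broadcasting over batch and heads, e.g. with head_dim=128:
#   q = torch.randn(2, 8, 6, 128)                                 # (B=2, T=8, H=6, D=128)
#   cos, sin = torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64)
#   q_rot = apply_rotary_emb(q, cos, sin)                         # same shape as q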

class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
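        # GQA sketch (illustrative, not the default config): with n_head=6 and n_kv_head=2,
        # each K/V head is shared by 6 // 2 = 3 query heads, shrinking the KV cache 3x;
        # the default n_kv_head == n_head recovers standard multi-head attention.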
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): mix in projected initial embedding for later layers
        if v0 is not None:
            v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim)
            v = v + v0_lambda * v0_reshaped

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k) # QK norm

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
        # window_size is a (left, right) tuple: (N, 0) restricts causal attention to the last N tokens, (-1, 0) allows the full causal context
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: use flash_attn_with_kvcache which handles cache management
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer processes
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)

        # Re-assemble the heads and project back to residual stream
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
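        # relu^2 ("squared ReLU") activation, e.g. relu(-2.0)**2 = 0.0 and relu(3.0)**2 = 9.0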
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda)
        x = x + self.mlp(norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        """
        NOTE a major footgun: this __init__ function runs in meta device context (!!)
        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
        """
        super().__init__()
        self.config = config
        # Compute per-layer window sizes for sliding window attention
        # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
        self.window_sizes = self._compute_window_sizes(config)
        # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
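        # e.g. a GPT-2 style vocab_size of 50257 rounds up to ((50257 + 63) // 64) * 64 = 50304;
        # the default vocab_size=50304 is already a multiple of 64, so no padding happens there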
        if padded_vocab_size != config.vocab_size:
            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
        # Value residual (ResFormer-style): separate embedding for values, mixed into later layers
        # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow
        # We apply to last 1/4 of layers as the paper shows later layers benefit most
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embed = nn.Embedding(padded_vocab_size, kv_dim)
        self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
        self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers
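        # e.g. with the default n_layer=12: value_residual_start = 12 - 12 // 4 = 9, so layers 9, 10, 11 receive v0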
        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
        # In the future we can dynamically grow the cache, for now it's fine.
        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.

        wte (embedding): normal, std=1.0
        lm_head: normal, std=0.001
        for each block:
            attn.c_q: uniform, std=1/sqrt(n_embd)
            attn.c_k: uniform, std=1/sqrt(n_embd)
            attn.c_v: uniform, std=1/sqrt(n_embd)
            attn.c_proj: zeros
            mlp.c_fc: uniform, std=1/sqrt(n_embd)
            mlp.c_proj: zeros
        """

        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
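        # e.g. for n_embd=768: s = sqrt(3/768) = 0.0625, and Uniform(-s, s) then has std s/sqrt(3) = 1/sqrt(768) ≈ 0.036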
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
        with torch.no_grad():
            self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
            self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
            self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init

        # Value embedding (init like c_v: uniform with same std)
        torch.nn.init.uniform_(self.value_embed.weight, -s, s)

        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

        # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)
            self.value_embed.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # TODO: bump base theta more? e.g. 100K is more common more recently
        # autodetect the device from model embeddings
        if device is None:
            device = self.transformer.wte.weight.device
        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
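        # e.g. with head_dim=128 and base=10000, inv_freq spans 1.0 (channel 0) down to 10000**(-126/128) ≈ 1.2e-4 (channel 126)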
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin

    def _compute_window_sizes(self, config):
        """
        Compute per-layer window sizes for sliding window attention.

        Returns list of (left, right) tuples for FA3's window_size parameter:
        - left: how many tokens before current position to attend to (-1 = unlimited)
        - right: how many tokens after current position to attend to (0 for causal)

        Pattern string is tiled across layers. Final layer always gets L (full context).
        Characters: L=long (full context), S=short (half context)
        """
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
        # Map characters to window sizes
        long_window = config.sequence_len
        short_window = long_window // 2
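        # e.g. with sequence_len=1024, "L" layers get a window of (1024, 0) and "S" layers get (512, 0)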
        char_to_window = {
            "L": (long_window, 0),
            "S": (short_window, 0),
        }
        # Tile pattern across layers
        window_sizes = []
        for layer_idx in range(config.n_layer):
            char = pattern[layer_idx % len(pattern)]
            window_sizes.append(char_to_window[char])
        # Final layer always gets full context
        window_sizes[-1] = (long_window, 0)
        return window_sizes

    def get_device(self):
        return self.transformer.wte.weight.device

    def estimate_flops(self):
        """
        Return the estimated FLOPs per token for the model (forward + backward).
        Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
        Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
        On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
        With sliding windows, effective_seq_len varies per layer (capped by window size).
        Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
        This is ~1% off from the exact formulas of the Chinchilla paper, the difference is:
        - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
        - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
        """
        nparams = sum(p.numel() for p in self.parameters())
        # Exclude non-matmul params: embeddings and per-layer scalars
        nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed.weight.numel() +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel())
        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # Sum attention FLOPs per layer, accounting for sliding window
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0] # (left, right) tuple, we use left
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
        return num_flops_per_token

    def num_scaling_params(self):
        """
        Return all of the parameters, same as the Chinchilla paper.
        Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
        But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
        My own experiments in nanochat confirm that the Chinchilla approach gives a much cleaner scaling law.
        Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
        Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
        """
        nparams = sum(p.numel() for p in self.parameters())
        return nparams

    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into groups (matrix, embedding, lm_head, value_embed, resid_lambdas, x0_lambdas, v0_lambdas)
        matrix_params = list(self.transformer.h.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        value_embed_params = list(self.value_embed.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        v0_params = [self.v0_lambdas]
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_params) + len(resid_params) + len(x0_params) + len(v0_params)
        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
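        # e.g. model_dim=1536 gives dmodel_lr_scale = (1536 / 768) ** -0.5 ≈ 0.707, i.e. those LRs shrink by ~30% vs the 768-dim baseline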
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
            dict(params=value_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
            dict(params=x0_params, lr=scalar_lr),
            dict(params=v0_params, lr=scalar_lr),
        ]
        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
        MuonFactory = DistMuon if ddp else Muon
        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        # Combine the two optimizers into one list
        optimizers = [adamw_optimizer, muon_optimizer]
        for opt in optimizers:
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers

    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
        B, T = idx.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

        # Forward the trunk of the Transformer
        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x # save initial normalized embedding for x0 residual
        # Value residual (ResFormer): separate value embedding for later layers
        v0 = self.value_embed(idx) # (B, T, kv_dim)
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            v0_for_layer = v0 if i >= self.value_residual_start else None
            x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0_for_layer, self.v0_lambdas[i])
        x = norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
        logits = logits[..., :self.config.vocab_size] # slice to remove padding
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits
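        # e.g. a raw logit of 30 becomes 15 * tanh(30 / 15) = 15 * tanh(2) ≈ 14.5, while small logits pass through nearly unchanged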

        if targets is not None:
            # training: given the targets, compute and return the loss
            # TODO experiment with chunked cross-entropy?
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: just return the logits directly
            return logits

    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Naive autoregressive streaming inference.
        To make it super simple, let's assume:
        - batch size is 1
        - ids and the yielded tokens are simple Python lists and ints
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
        for _ in range(max_tokens):
            logits = self.forward(ids) # (B, T, vocab_size)
            logits = logits[:, -1, :] # (B, vocab_size)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token
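
# Usage sketch (illustrative only; `prompt_tokens` is a hypothetical list of token ids from a tokenizer):
#   model = GPT(GPTConfig())
#   model.init_weights()
#   for tok in model.generate(prompt_tokens, max_tokens=16, temperature=0.8, top_k=50):
#       print(tok)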