Compare commits

23 Commits

Author SHA1 Message Date
Andrej Karpathy
05a051dbe9 fix tokenization bug, there should be no space before first letter. sigh 2025-10-24 15:06:06 +00:00
Andrej Karpathy
8892470f29 add the SpellingBee task so that nanochat can count r in strawberry etc. along the way we had to add a bunch of new functionality, e.g. extend the calculator to support the count function of python. possibly the current TaskMixture uses way too many synthetic examples of SpellingBee because the eval gives us exactly 100% performance on spelling. We can tune this later to reclaim some wall clock time here I think 2025-10-24 14:02:48 +00:00
Andrej Karpathy
81597cd616 move the lr schedule args up in base_train so they are tunable in configurator 2025-10-24 13:27:31 +00:00
Andrej Karpathy
cc3636b01c allow the tokenizer visualize_tokenization to also print the exact token id. you can never be paranoid enough 2025-10-24 13:27:05 +00:00
Andrej Karpathy
5eeb2b6ef9 experiment: looking to 'hire' a nanochat repo czar to help the repo, mentioning in readme 2025-10-22 16:55:54 +00:00
Andrej Karpathy
2dda5c4c8d Merge branch 'ulanch-fix/ios-safari-input-overlap' 2025-10-22 16:26:35 +00:00
Andrej Karpathy
80b203ea59 also bump run1000.sh to new uv sync 2025-10-22 16:25:36 +00:00
Luke Stanley
917c858136 Updates lockfile with CPU package support without overwriting other architectures 2025-10-22 16:25:36 +00:00
Luke Stanley
db1d5b595d Git ignore eval_bundle 2025-10-22 16:25:36 +00:00
Luke Stanley
dd9387b362 Fix GPU-less CPU use on Linux with specific Torch indexes 2025-10-22 16:25:36 +00:00
Luke Stanley
32571664b1 Fix Torch crash caused by pinning on CPU 2025-10-22 16:25:36 +00:00
Andrej Karpathy
51e70f0d3c Merge branch 'lukestanley-fix-cpu-support-with-extras' 2025-10-22 16:11:15 +00:00
Andrej Karpathy
48387cd895 also bump run1000.sh to new uv sync 2025-10-22 16:08:31 +00:00
ulanch
796f84527f fix(ui): prevent iOS Safari toolbar from covering input on initial load 2025-10-21 17:34:40 -07:00
Luke Stanley
7a52f9bfbb Updates lockfile with CPU package support without overwriting other architectures 2025-10-21 23:14:34 +00:00
Luke Stanley
760af62e11 Git ignore eval_bundle 2025-10-21 23:14:34 +00:00
Luke Stanley
901b075605 Fix GPU-less CPU use on Linux with specific Torch indexes 2025-10-21 23:14:16 +00:00
Luke Stanley
defd1246aa Fix Torch crash caused by pinning on CPU 2025-10-21 20:28:10 +00:00
Andrej
2e938530ce delete spurious torch.empty allocation in adamw
fix: remove unnecessary tensor allocation in DistAdamW optimizer
2025-10-21 11:35:17 -07:00
Andrej Karpathy
a088b7a6ec use enable_gqa of pytorch sdpa, allows us to delete some code, didnt realize it's available 2025-10-21 18:07:33 +00:00
Andrej Karpathy
94ee507054 quick fix base eval due to fewshot requirement 2025-10-21 17:56:08 +00:00
Andrej
33e8a27f91 Merge karpathy/cpu-mps-dev , adding the ability to run on CPU, on MPS, or on CUDA, with autodetect. Gnarly PR, nonzero chance I broke something.
add cpu|mps support
2025-10-21 10:26:04 -07:00
Sermet Pekin
49cd02f283 fix: remove unnecessary tensor allocation in DistAdamW optimizer
fix: remove unnecessary tensor allocation in DistAdamW optimizer
2025-10-20 12:03:26 +03:00
19 changed files with 714 additions and 112 deletions

1
.gitignore vendored
View File

@@ -4,3 +4,4 @@ __pycache__/
rustbpe/target/ rustbpe/target/
dev-ignore/ dev-ignore/
report.md report.md
eval_bundle/

View File

@@ -125,6 +125,8 @@ python -m pytest tests/test_rustbpe.py -v -s
nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
I am looking for someone to be the "nanochat repo czar" to help me manage the nanochat repo and its issues and PRs and be the first round of defense. Examples of work include merging simple fixes (docs, typos, clear and simple bugs etc.), rejecting vibe coded PRs, managing the Issues/PRs, doing brief "sanity check testing" of PRs on the two officially supported platforms (Linux/GPU and Macbook), organizing information into brief updates and highlights for me. We'd be in touch on DMs on Discord or X or whatever is easiest. For your services to the repo you will be listed and linked to under acknowledgements as the nanochat repo czar. Position is at-will so you can contribute for a while and then "resign" at any time later, totally ok and thank you for your help, just me know. Apply via DM to me on X, thank you!
## Acknowledgements ## Acknowledgements
- The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining. - The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.

View File

@@ -14,7 +14,7 @@ NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv [ -d ".venv" ] || uv venv
uv sync uv sync --extra cpu
source .venv/bin/activate source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy WANDB_RUN=dummy
@@ -53,7 +53,7 @@ python -m scripts.base_train \
--sample_every=50 \ --sample_every=50 \
--num_iterations=50 --num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=5 python -m scripts.base_eval --max-per-task=16
# midtraining # midtraining
python -m scripts.mid_train \ python -m scripts.mid_train \

View File

@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
grad_slices = [] grad_slices = []
for group in self.param_groups: for group in self.param_groups:
params: list[Tensor] = group["params"] params: list[Tensor] = group["params"]
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
for base_i in range(len(params)): for base_i in range(len(params)):
grad = params[base_i].grad grad = params[base_i].grad
rank_size = grad.shape[0] // world_size rank_size = grad.shape[0] // world_size

View File

@@ -5,6 +5,8 @@ Common utilities for nanochat.
import os import os
import re import re
import logging import logging
import fcntl
import urllib.request
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -56,6 +58,44 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True) os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir return nanochat_dir
def download_file_with_lock(url, filename):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
if os.path.exists(file_path):
return file_path
with open(lock_path, 'w') as lock_file:
# Only a single rank can acquire this lock
# All other ranks block until it is released
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
if os.path.exists(file_path):
return file_path
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read().decode('utf-8')
with open(file_path, 'w') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Clean up the lock file after the lock is released
try:
os.remove(lock_path)
except OSError:
pass # Ignore if already removed by another process
return file_path
def print0(s="",**kwargs): def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0)) ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0: if ddp_rank == 0:

View File

@@ -38,7 +38,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
batch_index += 1 batch_index += 1
# Move tokens from the deque into the scratch buffer # Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)] tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True) # CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
# Create the inputs/targets as 1D tensors # Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32) inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:] targets_cpu = scratch[1:]

View File

@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
return None return None
def use_calculator(expr): def use_calculator(expr):
"""Evaluate a math expression safely.""" """
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "") expr = expr.replace(",", "")
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None return None
if "**" in expr: # for now disallow power operator, could be very expensive
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr) return eval_with_timeout(expr)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
out = out.to(x.dtype) # ensure input/output dtypes match out = out.to(x.dtype) # ensure input/output dtypes match
return out return out
def repeat_kv(x, n_rep):
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
if n_rep == 1:
return x
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
class CausalSelfAttention(nn.Module): class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx): def __init__(self, config, layer_idx):
super().__init__() super().__init__()
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
Tq = q.size(2) # number of queries in this forward pass Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Apply MQA: replicate the key/value heads for each query head
nrep = self.n_head // self.n_kv_head
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
# Attention: queries attend to keys/values autoregressively. A few cases to handle: # Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk: if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention # During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk # And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True) y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1: elif Tq == 1:
# During inference but with a single query in this forward pass: # During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache # The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False) y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else: else:
# During inference AND we have a chunk of queries in this forward pass: # During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix) # First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
attn_mask[:, :prefix_len] = True attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk # Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream # Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1) y = y.transpose(1, 2).contiguous().view(B, T, -1)

View File

@@ -341,16 +341,19 @@ class RustBPETokenizer:
mask = mask[:max_tokens] mask = mask[:max_tokens]
return ids, mask return ids, mask
def visualize_tokenization(self, ids, mask): def visualize_tokenization(self, ids, mask, with_token_id=False):
"""Small helper function useful in debugging: visualize the tokenization of render_conversation""" """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
RED = '\033[91m' RED = '\033[91m'
GREEN = '\033[92m' GREEN = '\033[92m'
RESET = '\033[0m' RESET = '\033[0m'
GRAY = '\033[90m'
tokens = [] tokens = []
for i, (token_id, mask_val) in enumerate(zip(ids, mask)): for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
token_str = self.decode([token_id]) token_str = self.decode([token_id])
color = GREEN if mask_val == 1 else RED color = GREEN if mask_val == 1 else RED
tokens.append(f"{color}{token_str}{RESET}") tokens.append(f"{color}{token_str}{RESET}")
if with_token_id:
tokens.append(f"{GRAY}({token_id}){RESET}")
return '|'.join(tokens) return '|'.join(tokens)
def render_for_completion(self, conversation): def render_for_completion(self, conversation):

View File

@@ -2,7 +2,7 @@
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
<title>NanoChat</title> <title>NanoChat</title>
<link rel="icon" type="image/svg+xml" href="/logo.svg"> <link rel="icon" type="image/svg+xml" href="/logo.svg">
<style> <style>
@@ -18,7 +18,7 @@
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"; font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
background-color: #ffffff; background-color: #ffffff;
color: #111827; color: #111827;
min-height: 100vh; min-height: 100dvh;
margin: 0; margin: 0;
display: flex; display: flex;
flex-direction: column; flex-direction: column;
@@ -144,6 +144,7 @@
.input-container { .input-container {
background-color: #ffffff; background-color: #ffffff;
padding: 1rem; padding: 1rem;
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
} }
.input-wrapper { .input-wrapper {

View File

@@ -44,19 +44,35 @@ python_files = ["test_*.py"]
python_classes = ["Test*"] python_classes = ["Test*"]
python_functions = ["test_*"] python_functions = ["test_*"]
# target torch to cuda 12.8 # target torch to cuda 12.8 or CPU
[tool.uv.sources] [tool.uv.sources]
torch = [ torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" }, { index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, { index = "pytorch-cu128", extra = "gpu" },
] ]
[[tool.uv.index]] [[tool.uv.index]]
name = "pytorch-cpu" name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true explicit = true
[[tool.uv.index]] [project.optional-dependencies]
name = "pytorch-cu128" cpu = [
url = "https://download.pytorch.org/whl/cu128" "torch>=2.8.0",
explicit = true ]
gpu = [
"torch>=2.8.0",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]

View File

@@ -10,7 +10,7 @@ export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv [ -d ".venv" ] || uv venv
uv sync uv sync --extra gpu
source .venv/bin/activate source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy WANDB_RUN=dummy

View File

@@ -49,6 +49,9 @@ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam) weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon) matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled) grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
# Evaluation # Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on eval_tokens = 20*524288 # number of tokens to evaluate val loss on
@@ -85,7 +88,7 @@ print0(f"Vocab size: {vocab_size:,}")
num_layers = depth num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases) model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div) num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # 1:1 MQA ratio num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}") print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}") print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}") print0(f"num_heads: {num_heads}")
@@ -151,10 +154,6 @@ x, y = next(train_loader) # kick off load of the very first batch of data
# Set up hyperparameter schedulers # Set up hyperparameter schedulers
# Learning rate scheduler # Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
def get_lr_multiplier(it): def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations) warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations) warmdown_iters = round(warmdown_ratio * num_iterations)

View File

@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU from tasks.mmlu import MMLU
from tasks.arc import ARC from tasks.arc import ARC
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate) # Generative evaluation loop (we go one problem at a time, sample, evaluate)
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"), 'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"), 'GSM8K': partial(GSM8K, subset="main", split="test"),
'SpellingBee': partial(SpellingBee, size=256, split="test"),
}[task_name] }[task_name]
task_object = task_module() task_object = task_module()
# Run the evaluation # Run the evaluation
@@ -204,13 +206,14 @@ if __name__ == "__main__":
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
# Get the tasks to evaluate on # Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval'] all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
baseline_accuracies = { baseline_accuracies = {
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25% 'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0% 'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0% 'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
} }
task_names = all_tasks if args.task_name is None else args.task_name.split('|') task_names = all_tasks if args.task_name is None else args.task_name.split('|')

View File

@@ -28,6 +28,7 @@ from tasks.arc import ARC
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# SFT Hyperparameters # SFT Hyperparameters
@@ -86,7 +87,9 @@ train_ds = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([ val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
@@ -119,7 +122,8 @@ def mid_data_generator(split):
assert dataset_size > 0 assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque() token_buffer = deque()
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) # CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter it = 0 # iteration counter
while True: while True:

View File

@@ -23,7 +23,7 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist) # create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv [ -d ".venv" ] || uv venv
# install the repo dependencies # install the repo dependencies
uv sync uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python # activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate source .venv/bin/activate

305
tasks/spellingbee.py Normal file
View File

@@ -0,0 +1,305 @@
"""
Task intended to make nanochat better in spelling and counting, for example:
"How many r are in strawberry?" -> 3
An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model also learn recoveries. We can do this in future versions.
There are two tasks in this file:
1. SpellingBee: Counting the number of occurrences of a letter in a word
2. SimpleSpelling: Simply spelling words
(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.
To preview a few example conversations, run:
python -m tasks.spellingbee
"""
import re
import random
from tasks.common import Task
from nanochat.common import download_file_with_lock
# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
"""
match = ANSWER_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
return None
# User message templates for data augmentation
USER_MSG_TEMPLATES = [
"How many {letter} are in the word {word}",
"How many {letter} are in {word}",
"Count the number of {letter} in {word}",
"How many times does {letter} appear in {word}",
"What's the count of {letter} in {word}",
"In the word {word}, how many {letter} are there",
"How many letter {letter} are in the word {word}",
"Count how many {letter} appear in {word}",
"Tell me the number of {letter} in {word}",
"How many occurrences of {letter} are in {word}",
"Find the count of {letter} in {word}",
"Can you count the {letter} letters in {word}",
"What is the frequency of {letter} in {word}",
"How many {letter}s are in {word}",
"How many {letter}'s are in {word}",
"Count all the {letter} in {word}",
"How many times is {letter} in {word}",
"Number of {letter} in {word}",
"Total count of {letter} in {word}",
"How many {letter} does {word} have",
"How many {letter} does {word} contain",
"What's the number of {letter} in {word}",
"{word} has how many {letter}",
"In {word}, count the {letter}",
"How many {letter} appear in {word}",
"Count the {letter} in {word}",
"Give me the count of {letter} in {word}",
"How many instances of {letter} in {word}",
"Show me how many {letter} are in {word}",
"Calculate the number of {letter} in {word}",
# Spanish
"¿Cuántas {letter} hay en {word}?",
"¿Cuántas veces aparece {letter} en {word}?",
"Cuenta las {letter} en {word}",
"¿Cuántas letras {letter} tiene {word}?",
# Chinese (Simplified)
"{word}中有多少个{letter}",
"{word}里有几个{letter}",
"数一下{word}中的{letter}",
"{word}这个词里有多少{letter}",
# Korean
"{word}{letter}가 몇 개 있나요",
"{word}에서 {letter}의 개수는",
"{word}{letter}가 몇 번 나오나요",
"{word}라는 단어에 {letter}가 몇 개",
# French
"Combien de {letter} dans {word}",
"Combien de fois {letter} apparaît dans {word}",
"Compte les {letter} dans {word}",
# German
"Wie viele {letter} sind in {word}",
"Wie oft kommt {letter} in {word} vor",
"Zähle die {letter} in {word}",
# Japanese
"{word}{letter}は何個ありますか",
"{word}の中に{letter}がいくつ",
"{word}{letter}が何回出てくる",
]
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
# pick a letter from it (90%) or a random letter (10%)
letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)
# get the correct answer by simply counting
count = word.count(letter)
# create a user message, with a bunch of variations as data augmentation
template = rng.choice(USER_MSG_TEMPLATES)
# 30% chance to lowercase the template (lazy people don't use shift)
if rng.random() < 0.3:
template = template.lower()
quote_options = ['', "'", '"']
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
word_wrapped = f"{word_quote}{word}{word_quote}"
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
if rng.random() < 0.5: # 50% of people don't even use question marks
user_msg += "?"
# Now create the ideal assistant response - build as parts (text + tool calls)
assistant_parts = []
word_letters = ",".join(list(word))
manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.
First spell the word out:
{word}:{word_letters}
Then count the occurrences of '{letter}':
"""
# Little simulated loop of the solution process
# TODO: This is where the fun starts, we could simulate cute little mistakes
# and get the model to review its work and recover from them.
# You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
running_count = 0
for i, char in enumerate(word, 1):
if char == letter:
running_count += 1
# note: there deliberately cannot be a space here between i and char
# because this would create a different token! (e.g. " a" and "a" are different tokens)
manual_text += f"{i}:{char} hit! count={running_count}\n"
else:
manual_text += f"{i}:{char}\n"
manual_text += f"\nThis gives us {running_count}."
assistant_parts.append({"type": "text", "text": manual_text})
# Part 2: Python verification
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
# Part 3: Python tool call
python_expr = f"'{word}'.count('{letter}')"
assistant_parts.append({"type": "python", "text": python_expr})
# Part 4: Python output
assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
# return the full conversation
messages = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts}
]
conversation = {
"messages": messages,
}
return conversation
def evaluate(self, conversation, assistant_response):
"""
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
Identical to gsm8k's evaluation.
"""
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer from the conversation
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
# The last text part contains the final answer with ####
last_text_part = assistant_message['content'][-1]['text']
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
# Compare and return the success as int
is_correct = int(pred_num == ref_num)
return is_correct
def reward(self, conversation, assistant_response):
""" Use simple 0-1 reward just like gsm8k."""
is_correct = self.evaluate(conversation, assistant_response)
is_correct_float = float(is_correct)
return is_correct_float
class SimpleSpelling(Task):
"""Much simpler task designed to get the model to just practice spelling words."""
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
word_letters = ",".join(list(word))
# return the full conversation
messages = [
{"role": "user", "content": f"Spell the word: {word}"},
{"role": "assistant", "content": f"{word}:{word_letters}"}
]
conversation = {
"messages": messages,
}
return conversation
if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples
task = SpellingBee()
for i in range(10):
ex = task.get_example(i)
print("=" * 100)
print(ex['messages'][0]['content'])
print("-" * 100)
# Assistant content is now a list of parts
assistant_parts = ex['messages'][1]['content']
for part in assistant_parts:
if part['type'] == 'text':
print(part['text'], end='')
elif part['type'] == 'python':
print(f"<<{part['text']}=", end='')
elif part['type'] == 'python_output':
print(f"{part['text']}>>", end='')
print()
print("-" * 100)
# # preview the SimpleSpelling task, first 10 examples
# task = SimpleSpelling()
# for i in range(10):
# ex = task.get_example(i)
# print("=" * 100)
# print(ex['messages'][0]['content'])
# print("-" * 100)
# print(ex['messages'][1]['content'])
# # also scrutinize the tokenization (last example only)
# from nanochat.tokenizer import get_tokenizer
# tokenizer = get_tokenizer()
# ids, mask = tokenizer.render_conversation(ex)
# print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))

335
uv.lock generated
View File

@@ -2,13 +2,32 @@ version = 1
revision = 3 revision = 3
requires-python = ">=3.10" requires-python = ">=3.10"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'linux'", "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'linux'", "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
] ]
conflicts = [[
{ package = "nanochat", extra = "cpu" },
{ package = "nanochat", extra = "gpu" },
]]
[[package]] [[package]]
name = "aiohappyeyeballs" name = "aiohappyeyeballs"
@@ -26,7 +45,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "aiohappyeyeballs" }, { name = "aiohappyeyeballs" },
{ name = "aiosignal" }, { name = "aiosignal" },
{ name = "async-timeout", marker = "python_full_version < '3.11'" }, { name = "async-timeout", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "attrs" }, { name = "attrs" },
{ name = "frozenlist" }, { name = "frozenlist" },
{ name = "multidict" }, { name = "multidict" },
@@ -111,7 +130,7 @@ version = "1.4.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "frozenlist" }, { name = "frozenlist" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" }, { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" }
wheels = [ wheels = [
@@ -132,10 +151,10 @@ name = "anyio"
version = "4.10.0" version = "4.10.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "idna" }, { name = "idna" },
{ name = "sniffio" }, { name = "sniffio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" }, { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" } sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" }
wheels = [ wheels = [
@@ -238,7 +257,7 @@ name = "click"
version = "8.2.1" version = "8.2.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" }, { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
wheels = [ wheels = [
@@ -292,7 +311,7 @@ name = "exceptiongroup"
version = "1.3.0" version = "1.3.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [ wheels = [
@@ -497,7 +516,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "filelock" }, { name = "filelock" },
{ name = "fsspec" }, { name = "fsspec" },
{ name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "packaging" }, { name = "packaging" },
{ name = "pyyaml" }, { name = "pyyaml" },
{ name = "requests" }, { name = "requests" },
@@ -602,7 +621,7 @@ name = "maturin"
version = "1.9.4" version = "1.9.4"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "tomli", marker = "python_full_version < '3.11'" }, { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/13/7c/b11b870fc4fd84de2099906314ce45488ae17be32ff5493519a6cddc518a/maturin-1.9.4.tar.gz", hash = "sha256:235163a0c99bc6f380fb8786c04fd14dcf6cd622ff295ea3de525015e6ac40cf", size = 213647, upload-time = "2025-08-27T11:37:57.079Z" } sdist = { url = "https://files.pythonhosted.org/packages/13/7c/b11b870fc4fd84de2099906314ce45488ae17be32ff5493519a6cddc518a/maturin-1.9.4.tar.gz", hash = "sha256:235163a0c99bc6f380fb8786c04fd14dcf6cd622ff295ea3de525015e6ac40cf", size = 213647, upload-time = "2025-08-27T11:37:57.079Z" }
wheels = [ wheels = [
@@ -635,7 +654,7 @@ name = "multidict"
version = "6.6.4" version = "6.6.4"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/69/7f/0652e6ed47ab288e3756ea9c0df8b14950781184d4bd7883f4d87dd41245/multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd", size = 101843, upload-time = "2025-08-11T12:08:48.217Z" } sdist = { url = "https://files.pythonhosted.org/packages/69/7f/0652e6ed47ab288e3756ea9c0df8b14950781184d4bd7883f4d87dd41245/multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd", size = 101843, upload-time = "2025-08-11T12:08:48.217Z" }
wheels = [ wheels = [
@@ -761,13 +780,26 @@ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "psutil" }, { name = "psutil" },
{ name = "regex" }, { name = "regex" },
{ name = "setuptools" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "tokenizers" }, { name = "tokenizers" },
{ name = "torch" }, { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "uvicorn" }, { name = "uvicorn" },
{ name = "wandb" }, { name = "wandb" },
] ]
[package.optional-dependencies]
cpu = [
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
gpu = [
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } },
]
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "maturin" }, { name = "maturin" },
@@ -782,12 +814,16 @@ requires-dist = [
{ name = "numpy", specifier = "==1.26.4" }, { name = "numpy", specifier = "==1.26.4" },
{ name = "psutil", specifier = ">=7.1.0" }, { name = "psutil", specifier = ">=7.1.0" },
{ name = "regex", specifier = ">=2025.9.1" }, { name = "regex", specifier = ">=2025.9.1" },
{ name = "setuptools", specifier = ">=80.9.0" },
{ name = "tiktoken", specifier = ">=0.11.0" }, { name = "tiktoken", specifier = ">=0.11.0" },
{ name = "tokenizers", specifier = ">=0.22.0" }, { name = "tokenizers", specifier = ">=0.22.0" },
{ name = "torch", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "torch", specifier = ">=2.8.0" },
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
{ name = "uvicorn", specifier = ">=0.36.0" }, { name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" }, { name = "wandb", specifier = ">=0.21.3" },
] ]
provides-extras = ["cpu", "gpu"]
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
@@ -800,8 +836,13 @@ name = "networkx"
version = "3.4.2" version = "3.4.2"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
resolution-markers = [ resolution-markers = [
"python_full_version < '3.11' and sys_platform == 'linux'", "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'linux'", "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
] ]
sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" }
wheels = [ wheels = [
@@ -813,10 +854,20 @@ name = "networkx"
version = "3.5" version = "3.5"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
] ]
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
wheels = [ wheels = [
@@ -860,7 +911,9 @@ name = "nvidia-cublas-cu12"
version = "12.8.4.1" version = "12.8.4.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
{ url = "https://files.pythonhosted.org/packages/70/61/7d7b3c70186fb651d0fbd35b01dbfc8e755f69fd58f817f3d0f642df20c3/nvidia_cublas_cu12-12.8.4.1-py3-none-win_amd64.whl", hash = "sha256:47e9b82132fa8d2b4944e708049229601448aaad7e6f296f630f2d1a32de35af", size = 567544208, upload-time = "2025-03-07T01:53:30.535Z" },
] ]
[[package]] [[package]]
@@ -868,7 +921,9 @@ name = "nvidia-cuda-cupti-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
{ url = "https://files.pythonhosted.org/packages/41/bc/83f5426095d93694ae39fe1311431b5d5a9bb82e48bf0dd8e19be2765942/nvidia_cuda_cupti_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:bb479dcdf7e6d4f8b0b01b115260399bf34154a1a2e9fe11c85c517d87efd98e", size = 7015759, upload-time = "2025-03-07T01:51:11.355Z" },
] ]
[[package]] [[package]]
@@ -877,6 +932,8 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
{ url = "https://files.pythonhosted.org/packages/45/51/52a3d84baa2136cc8df15500ad731d74d3a1114d4c123e043cb608d4a32b/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:7a4b6b2904850fe78e0bd179c4b655c404d4bb799ef03ddc60804247099ae909", size = 73586838, upload-time = "2025-03-07T01:52:13.483Z" },
] ]
[[package]] [[package]]
@@ -884,7 +941,9 @@ name = "nvidia-cuda-runtime-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
{ url = "https://files.pythonhosted.org/packages/30/a5/a515b7600ad361ea14bfa13fb4d6687abf500adc270f19e89849c0590492/nvidia_cuda_runtime_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:c0c6027f01505bfed6c3b21ec546f69c687689aad5f1a377554bc6ca4aa993a8", size = 944318, upload-time = "2025-03-07T01:51:01.794Z" },
] ]
[[package]] [[package]]
@@ -892,10 +951,12 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21" version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
{ url = "https://files.pythonhosted.org/packages/3d/90/0bd6e586701b3a890fd38aa71c387dab4883d619d6e5ad912ccbd05bfd67/nvidia_cudnn_cu12-9.10.2.21-py3-none-win_amd64.whl", hash = "sha256:c6288de7d63e6cf62988f0923f96dc339cea362decb1bf5b3141883392a7d65e", size = 692992268, upload-time = "2025-06-06T21:55:18.114Z" },
] ]
[[package]] [[package]]
@@ -903,10 +964,12 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83" version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
{ url = "https://files.pythonhosted.org/packages/7d/ec/ce1629f1e478bb5ccd208986b5f9e0316a78538dd6ab1d0484f012f8e2a1/nvidia_cufft_cu12-11.3.3.83-py3-none-win_amd64.whl", hash = "sha256:7a64a98ef2a7c47f905aaf8931b69a3a43f27c55530c698bb2ed7c75c0b42cb7", size = 192216559, upload-time = "2025-03-07T01:53:57.106Z" },
] ]
[[package]] [[package]]
@@ -915,6 +978,7 @@ version = "1.13.1.3"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
] ]
[[package]] [[package]]
@@ -922,7 +986,9 @@ name = "nvidia-curand-cu12"
version = "10.3.9.90" version = "10.3.9.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
{ url = "https://files.pythonhosted.org/packages/b9/75/70c05b2f3ed5be3bb30b7102b6eb78e100da4bbf6944fd6725c012831cab/nvidia_curand_cu12-10.3.9.90-py3-none-win_amd64.whl", hash = "sha256:f149a8ca457277da854f89cf282d6ef43176861926c7ac85b2a0fbd237c587ec", size = 62765309, upload-time = "2025-03-07T01:54:20.478Z" },
] ]
[[package]] [[package]]
@@ -930,12 +996,14 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90" version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
{ url = "https://files.pythonhosted.org/packages/13/c0/76ca8551b8a84146ffa189fec81c26d04adba4bc0dbe09cd6e6fd9b7de04/nvidia_cusolver_cu12-11.7.3.90-py3-none-win_amd64.whl", hash = "sha256:4a550db115fcabc4d495eb7d39ac8b58d4ab5d8e63274d3754df1c0ad6a22d34", size = 256720438, upload-time = "2025-03-07T01:54:39.898Z" },
] ]
[[package]] [[package]]
@@ -943,10 +1011,12 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93" version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
{ url = "https://files.pythonhosted.org/packages/62/07/f3b2ad63f8e3d257a599f422ae34eb565e70c41031aecefa3d18b62cabd1/nvidia_cusparse_cu12-12.5.8.93-py3-none-win_amd64.whl", hash = "sha256:9a33604331cb2cac199f2e7f5104dfbb8a5a898c367a53dfda9ff2acb6b6b4dd", size = 284937404, upload-time = "2025-03-07T01:55:07.742Z" },
] ]
[[package]] [[package]]
@@ -954,7 +1024,9 @@ name = "nvidia-cusparselt-cu12"
version = "0.7.1" version = "0.7.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
{ url = "https://files.pythonhosted.org/packages/2f/d8/a6b0d0d0c2435e9310f3e2bb0d9c9dd4c33daef86aa5f30b3681defd37ea/nvidia_cusparselt_cu12-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f67fbb5831940ec829c9117b7f33807db9f9678dc2a617fbe781cac17b4e1075", size = 271020911, upload-time = "2025-02-26T00:14:47.204Z" },
] ]
[[package]] [[package]]
@@ -962,6 +1034,7 @@ name = "nvidia-nccl-cu12"
version = "2.27.3" version = "2.27.3"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768, upload-time = "2025-06-03T21:57:30.234Z" },
{ url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" },
] ]
@@ -971,6 +1044,8 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
{ url = "https://files.pythonhosted.org/packages/ed/d7/34f02dad2e30c31b10a51f6b04e025e5dd60e5f936af9045a9b858a05383/nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f", size = 268553710, upload-time = "2025-03-07T01:56:24.13Z" },
] ]
[[package]] [[package]]
@@ -978,7 +1053,9 @@ name = "nvidia-nvtx-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
{ url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
] ]
[[package]] [[package]]
@@ -1334,13 +1411,13 @@ name = "pytest"
version = "8.4.2" version = "8.4.2"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" }, { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "iniconfig" }, { name = "iniconfig" },
{ name = "packaging" }, { name = "packaging" },
{ name = "pluggy" }, { name = "pluggy" },
{ name = "pygments" }, { name = "pygments" },
{ name = "tomli", marker = "python_full_version < '3.11'" }, { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" }
wheels = [ wheels = [
@@ -1561,7 +1638,7 @@ version = "0.48.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" }, { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" }
wheels = [ wheels = [
@@ -1684,30 +1761,38 @@ wheels = [
name = "torch" name = "torch"
version = "2.8.0+cu128" version = "2.8.0+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" } source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [ dependencies = [
{ name = "filelock" }, { name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "fsspec" }, { name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "jinja2" }, { name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy" }, { name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions" }, { name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" },
@@ -1722,12 +1807,143 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" },
] ]
[[package]]
name = "torch"
version = "2.9.0"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version < '3.11' and sys_platform == 'darwin'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:59484193b01299bf669520505a72b29d59a0028ae4c6d95f492938f186592208" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aa4483602586cc9a35d1cf33771a9977f05f642b9161518a289e36548a0b77c2" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:4de0ed8cbc457a506dbca40376e206a29efee10756a00f1f3404bf67ad737d04" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:259548471194ab63d7ea273873053a6e3cc23530c1510f01e9d7ad259187bbd0" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e24836d968b54ef4dfb05594001a61958711ac9224026291e4e3f92f83a6fd7f" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d8e2ab7f86010330bdcc39c8b2c795590cc75e37df4823cdaee2c98d6e3ff4a3" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a3e859039c985d8e3ea60d7a54ca7e97ea2ae15e31beced4f3260128a161bb01" },
]
[[package]]
name = "torch"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" },
{ url = "https://files.pythonhosted.org/packages/58/1d/fd1e88ae0948825efcab7dd66d12bec23f05d4d38ed81573c8d453c14c06/torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:51cb63902182a78e90886e8068befd8ea102af4b00e420263591a3d70c7d3c6c", size = 899795167, upload-time = "2025-10-15T15:47:12.695Z" },
{ url = "https://files.pythonhosted.org/packages/63/5a/496197b45c14982bef4e079b24c61dc108e3ab0d0cc9718dba9f54f45a46/torch-2.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:3f6aad4d2f0ee2248bac25339d74858ff846c3969b27d14ac235821f055af83d", size = 109310314, upload-time = "2025-10-15T15:46:16.633Z" },
{ url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" },
{ url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" },
{ url = "https://files.pythonhosted.org/packages/05/cc/49566caaa218872ec9a2912456f470ff92649894a4bc2e5274aa9ef87c4a/torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:51de31219c97c51cf4bf2be94d622e3deb5dcc526c6dc00e97c17eaec0fc1d67", size = 899815990, upload-time = "2025-10-15T15:48:03.336Z" },
{ url = "https://files.pythonhosted.org/packages/74/25/e9ab21d5925b642d008f139d4a3c9664fc9ee1faafca22913c080cc4c0a5/torch-2.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd515c70059afd95f48b8192733764c08ca37a1d19803af6401b5ecad7c8676e", size = 109313698, upload-time = "2025-10-15T15:46:12.425Z" },
{ url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" },
{ url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" },
{ url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" },
{ url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" },
{ url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" },
{ url = "https://files.pythonhosted.org/packages/c2/1c/90eb13833cdf4969ea9707586d7b57095c3b6e2b223a7256bf111689bcb8/torch-2.9.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c30a17fc83eeab346913e237c64b15b5ba6407fff812f6c541e322e19bc9ea0e", size = 104111330, upload-time = "2025-10-15T15:46:35.238Z" },
{ url = "https://files.pythonhosted.org/packages/0e/21/2254c54b8d523592c25ef4434769aa23e29b1e6bf5f4c0ad9e27bf442927/torch-2.9.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f25033b8667b57857dfd01458fbf2a9e6a6df1f8def23aef0dc46292f6aa642", size = 899750243, upload-time = "2025-10-15T15:48:57.459Z" },
{ url = "https://files.pythonhosted.org/packages/b7/a5/5cb94fa4fd1e78223455c23c200f30f6dc10c6d4a2bcc8f6e7f2a2588370/torch-2.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:d037f1b4ffd25013be4a7bf3651a0a910c68554956c7b2c92ebe87c76475dece", size = 109284513, upload-time = "2025-10-15T15:46:45.061Z" },
{ url = "https://files.pythonhosted.org/packages/66/e8/fc414d8656250ee46120b44836ffbb3266343db424b3e18ca79ebbf69d4f/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e4e5b5cba837a2a8d1a497ba9a58dae46fa392593eaa13b871c42f71847503a5", size = 74830362, upload-time = "2025-10-15T15:46:48.983Z" },
{ url = "https://files.pythonhosted.org/packages/ed/5f/9474c98fc5ae0cd04b9466035428cd360e6611a86b8352a0fc2fa504acdc/torch-2.9.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:64693568f5dc4dbd5f880a478b1cea0201cc6b510d91d1bc54fea86ac5d1a637", size = 104144940, upload-time = "2025-10-15T15:47:29.076Z" },
{ url = "https://files.pythonhosted.org/packages/2d/5a/8e0c1cf57830172c109d4bd6be2708cabeaf550983eee7029291322447a0/torch-2.9.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:f8ed31ddd7d10bfb3fbe0b9fe01b1243577f13d75e6f4a0839a283915ce3791e", size = 899744054, upload-time = "2025-10-15T15:48:29.864Z" },
{ url = "https://files.pythonhosted.org/packages/6d/28/82c28b30fcb4b7c9cdd995763d18bbb830d6521356712faebbad92ffa61d/torch-2.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:eff527d4e4846e6f70d2afd8058b73825761203d66576a7e04ea2ecfebcb4ab8", size = 109517546, upload-time = "2025-10-15T15:47:33.395Z" },
{ url = "https://files.pythonhosted.org/packages/ff/c3/a91f96ec74347fa5fd24453fa514bc61c61ecc79196fa760b012a1873d96/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:f8877779cf56d1ce431a7636703bdb13307f5960bb1af49716d8b179225e0e6a", size = 74480732, upload-time = "2025-10-15T15:47:38.002Z" },
{ url = "https://files.pythonhosted.org/packages/5c/73/9f70af34b334a7e0ef496ceec96b7ec767bd778ea35385ce6f77557534d1/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7e614fae699838038d888729f82b687c03413c5989ce2a9481f9a7e7a396e0bb", size = 74433037, upload-time = "2025-10-15T15:47:41.894Z" },
{ url = "https://files.pythonhosted.org/packages/b7/84/37cf88625901934c97109e583ecc21777d21c6f54cda97a7e5bbad1ee2f2/torch-2.9.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:dfb5b8cd310ba3436c7e14e8b7833ef658cf3045e50d2bdaed23c8fc517065eb", size = 104116482, upload-time = "2025-10-15T15:47:46.266Z" },
{ url = "https://files.pythonhosted.org/packages/56/8e/ca8b17866943a8d4f4664d402ea84210aa274588b4c5d89918f5caa24eec/torch-2.9.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b3d29524993a478e46f5d598b249cd824b7ed98d7fba538bd9c4cde6c803948f", size = 899746916, upload-time = "2025-10-15T15:50:40.294Z" },
{ url = "https://files.pythonhosted.org/packages/43/65/3b17c0fbbdab6501c5b320a52a648628d0d44e7379f64e27d9eef701b6bf/torch-2.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:71c7578984f5ec0eb645eb4816ac8435fcf3e3e2ae1901bcd2f519a9cafb5125", size = 109275151, upload-time = "2025-10-15T15:49:20.715Z" },
{ url = "https://files.pythonhosted.org/packages/83/36/74f8c051f785500396e42f93542422422dfd874a174f21f8d955d36e5d64/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:71d9309aee457bbe0b164bce2111cd911c4ed4e847e65d5077dbbcd3aba6befc", size = 74823353, upload-time = "2025-10-15T15:49:16.59Z" },
{ url = "https://files.pythonhosted.org/packages/62/51/dc3b4e2f9ba98ae27238f0153ca098bf9340b2dafcc67fde645d496dfc2a/torch-2.9.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c08fb654d783899e204a32cca758a7ce8a45b2d78eeb89517cc937088316f78e", size = 104140340, upload-time = "2025-10-15T15:50:19.67Z" },
{ url = "https://files.pythonhosted.org/packages/c0/8d/b00657f8141ac16af7bb6cda2e67de18499a3263b78d516b9a93fcbc98e3/torch-2.9.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ec8feb0099b2daa5728fbc7abb0b05730fd97e0f359ff8bda09865aaa7bd7d4b", size = 899731750, upload-time = "2025-10-15T15:49:36.673Z" },
{ url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
]
[[package]]
name = "torch"
version = "2.9.0+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:b224792ea567b52c7f1ce1d789567f6920e06fd3b339fa1e1b05948845f783ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:bd2a257e670ede9fc01c6d76dccdc473040913b8e9328169bf177dbdc38e2484" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:96f3f7aa4eb9e7fc5af8a722eaf1e5e32e3039dbafe817178d7b90a8566be32d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:da77341ccaba31762d9238b0942c165c4582a26818f3045b052b39cebdd7ad9d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:add3e93ecc1eeaa6853f6a973ce60ffb3cb14ed2e80f5055e139b09385dce0a7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:389e1e0b8083fd355f7caf5ba82356b5e01c318998bd575dbf2285a0d8137089" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:5ce3d01aef91dc078fbb121814e556d55bc886d303efaf42c4fe67e411f5f9ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3a651434ae1248b0568c12b5f9e3acc8942eb28378d9d04a79302938b68c6f24" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:28f6eb31b08180a5c5e98d5bc14eef6909c9f5a1dbff9632c3e02a8773449349" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:e438061b87ec7dd6018fca9f975219889aa0a3f6cdc3ea10dd0ae2bc7f1c47ce" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:eb13ff1c34e338d722e76a4fd83b8d282782505bd1b99af4b3c32da66eba6eb4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:be4438d8dad7f0d5a5e54f0feef8a893446894ec87f102bb1d82dcc4518542e4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c9b217584400963d5b4daddb3711ec7a3778eab211e18654fba076cce3b8682" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:728372e3f58c5826445f677746e5311c1935c1a7c59599f73a49ded850e038e8" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:95e56c26f919fbb98f16e7a0b87af494b893f9da9a65a020f17a01c13e520a81" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:6c777160288b08555820781ae0f3a2c67a59bd24b065e88ca1ec20e2f9dc8ac7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:528fd338311f31c9fb18038cafd00e6eae0bf5ad5577521701acb62510753d18" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:d572863990e7d2762b547735ef589f6350d9eb4e441d38753a1c33636698cf4c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:44aadb735774d4a99525d2ec29126b23016c44a07b02ce6c237dfa61a223dd52" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b355e07b7f0c369cb031adfcbff5c37a609abcea091b918a39886412afd2e07d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:c2698999361d73c2d25d7cc8a787130188d49b183abb18b554228daa102e1594" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:fa0d1373d04b30ff8f12d542135d292f1a1ddb7c0d852a3d487a320360e5dab9" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:2f49bb57a5fe0dc7f8e73ea9e5d36ebda2ea25b8a714a788f0fc2fc47d20a830" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:3a60d1ecf27a9cce839b3aa665b26f0af1b1007b9c9f1e7f597f6b7bdf107617" },
]
[[package]] [[package]]
name = "tqdm" name = "tqdm"
version = "4.67.1" version = "4.67.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" }, { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
wheels = [ wheels = [
@@ -1739,7 +1955,7 @@ name = "triton"
version = "3.4.0" version = "3.4.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "setuptools", marker = "sys_platform == 'linux'" }, { name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },
@@ -1795,7 +2011,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "click" }, { name = "click" },
{ name = "h11" }, { name = "h11" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/ef/5e/f0cd46063a02fd8515f0e880c37d2657845b7306c16ce6c4ffc44afd9036/uvicorn-0.36.0.tar.gz", hash = "sha256:527dc68d77819919d90a6b267be55f0e76704dca829d34aea9480be831a9b9d9", size = 80032, upload-time = "2025-09-20T01:07:14.418Z" } sdist = { url = "https://files.pythonhosted.org/packages/ef/5e/f0cd46063a02fd8515f0e880c37d2657845b7306c16ce6c4ffc44afd9036/uvicorn-0.36.0.tar.gz", hash = "sha256:527dc68d77819919d90a6b267be55f0e76704dca829d34aea9480be831a9b9d9", size = 80032, upload-time = "2025-09-20T01:07:14.418Z" }
wheels = [ wheels = [
@@ -2002,4 +2218,3 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
] ]