Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)

Commit: delete the configurator in favor of argparse and clean up a lot of kwarg details to make them more consistent across all scripts
@@ -140,7 +140,6 @@ python -m pytest tests/test_engine.py -v -s
│   ├── adamw.py                # Distributed AdamW optimizer
│   ├── checkpoint_manager.py   # Save/Load model checkpoints
│   ├── common.py               # Misc small utilities, quality of life
│   ├── configurator.py         # A superior alternative to argparse
│   ├── core_eval.py            # Evaluates base model CORE score (DCLM paper)
│   ├── dataloader.py           # Tokenizing Distributed Data Loader
│   ├── dataset.py              # Download/read utils for pretraining data
@@ -1,56 +0,0 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import os
import sys
from ast import literal_eval

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print0(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print0(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            if globals()[key] is not None:
                attempt_type = type(attempt)
                default_type = type(globals()[key])
                assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
            # cross fingers
            print0(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
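For orientation, a minimal sketch of the argparse pattern that replaces the configurator in the scripts below (argument names are taken from the diff; the surrounding training code is elided):

import argparse

parser = argparse.ArgumentParser(description="Pretrain base model")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
args = parser.parse_args()
user_config = vars(args).copy()  # for logging, replaces the old config_keys/globals() bookkeeping

# settings are now read as args.device_batch_size etc. instead of bare module-level globals
print(f"run={args.run}, device_batch_size={args.device_batch_size}")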
@@ -167,7 +167,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None:
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
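A self-contained sketch of the sampling step this hunk guards; the softmax/multinomial tail is not shown in the diff, so that part is an assumption based on standard top-k sampling rather than the exact nanochat code. The new condition means top_k=0 now behaves like top_k=None (no truncation):

import torch

def sample_top_k(logits, temperature=1.0, top_k=50, generator=None):
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:  # 0 disables top-k, same as None
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        probs = torch.softmax(vals / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=generator)
        return idx.gather(-1, choice)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)

next_token = sample_top_k(torch.randn(2, 50304), top_k=0)  # falls through to full-vocab sampling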
@@ -6,7 +6,7 @@ Loads a checkpoint, and:
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
import argparse
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
@@ -16,29 +16,30 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine

# Configuration
device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
args = parser.parse_args()

# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = args.split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
    loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
    loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
    with autocast_ctx:
        bpb = evaluate_bpb(model, loader, steps, token_bytes)
    print0(f"{split_name} bpb: {bpb:.4f}")
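As a sanity check on the evaluation length, here is the arithmetic under the script defaults, assuming sequence_len=2048 from the checkpoint meta and an 8-rank torchrun launch:

device_batch_size = 32
sequence_len = 2048            # read from meta at runtime; 2048 assumed here
ddp_world_size = 8             # assumed 8-GPU launch
split_tokens = 20 * 524288

tokens_per_step = device_batch_size * sequence_len * ddp_world_size  # 524,288 tokens per step
assert split_tokens % tokens_per_step == 0
steps = split_tokens // tokens_per_step  # 20 evaluation steps per split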
@@ -13,6 +13,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
from contextlib import nullcontext

@@ -30,46 +31,46 @@ from scripts.base_eval import evaluate_model
print_banner()

# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# CLI arguments
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target_param_data_ratio", type=int, default=20, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping value (0.0 = disabled)")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.2, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps")
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
@@ -77,8 +78,8 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)

# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@@ -87,8 +88,8 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_layers = args.depth
model_dim = args.depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
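Worked out for the default --depth=20, the derived model kwargs come to:

depth = 20
num_layers = depth                            # 20
model_dim = depth * 64                        # 1280 (aspect ratio 64)
num_heads = max(1, (model_dim + 127) // 128)  # ceil(1280 / 128) = 10 heads of dim 128
num_kv_heads = num_heads                      # 10, i.e. GQA effectively disabled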
@@ -98,19 +99,19 @@ print0(f"num_kv_heads: {num_kv_heads}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
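Under the defaults and an assumed 8-rank launch, the gradient accumulation works out to a single micro-step; on fewer GPUs the same total batch size is reached by accumulating:

device_batch_size = 32
max_seq_len = 2048
ddp_world_size = 8             # assumed; with 1 GPU, grad_accum_steps becomes 8
total_batch_size = 524288

tokens_per_fwdbwd = device_batch_size * max_seq_len             # 65,536 tokens per rank per micro-batch
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 524,288 tokens per micro-batch overall
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 1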
# -----------------------------------------------------------------------------
# Initialize the Model

# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
    # All tensors are created as meta tensors (they have shape/dtype but no data)
    model_config = GPTConfig(**model_config_kwargs)
@@ -120,12 +121,12 @@ model.init_weights() # All tensors get initialized

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
resuming = args.resume_from_step != -1
if resuming:
    print0(f"Resuming optimization from step {resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    print0(f"Resuming optimization from step {args.resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # free up this memory after the copy
@@ -137,28 +138,29 @@ num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
elif args.target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
elif args.target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    target_tokens = args.target_param_data_ratio * num_params
    num_iterations = target_tokens // args.total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers

if resuming:
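To make the precedence concrete, here is the default path (target_param_data_ratio=20) with an illustrative parameter count; num_params comes from the instantiated model at runtime, and the value below is only a plausible order of magnitude for the default depth:

num_params = 560_000_000        # hypothetical, for illustration only
target_param_data_ratio = 20
total_batch_size = 524288

target_tokens = target_param_data_ratio * num_params    # 11.2B tokens
num_iterations = target_tokens // total_batch_size       # 21,362 steps
total_tokens = total_batch_size * num_iterations
print(total_tokens / num_params)                          # just under 20, as intended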
@@ -170,8 +172,8 @@ if resuming:
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
@@ -179,15 +181,15 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir

# Learning rate scheduler
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    warmup_iters = round(args.warmup_ratio * num_iterations)
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
        return progress * 1.0 + (1 - progress) * args.final_lr_frac
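Under the defaults (warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0) the multiplier stays flat at 1.0 for the first 80% of training and then ramps linearly to zero; a quick illustration with a hypothetical num_iterations=1000:

num_iterations = 1000
warmup_ratio, warmdown_ratio, final_lr_frac = 0.0, 0.2, 0.0

def lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)      # 0 -> warmup branch never taken
    warmdown_iters = round(warmdown_ratio * num_iterations)  # 200
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

print([lr_multiplier(i) for i in (0, 500, 800, 900, 1000)])  # [1.0, 1.0, 1.0, 0.5, 0.0]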
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@@ -215,13 +217,13 @@ else:
# Training loop
while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * total_batch_size * step
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
    if last_step or step % args.eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
@@ -238,10 +240,10 @@ while True:
    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
@@ -253,7 +255,7 @@ while True:
    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
    if master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
@@ -273,7 +275,7 @@ while True:
        model.train()

    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
@@ -284,8 +286,8 @@ while True:
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
                "device_batch_size": args.device_batch_size,
                "max_seq_len": args.max_seq_len,
                "dataloader_state_dict": dataloader_state_dict,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
@@ -313,9 +315,9 @@ while True:
        loss.backward()
    x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
    # gradient clipping
    grad_clip_enabled = grad_clip > 0.0
    grad_clip_enabled = args.grad_clip > 0.0
    if grad_clip_enabled:
        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), args.grad_clip)
        grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
    # step the optimizers
    lrm = get_lr_multiplier(step)
@@ -338,8 +340,8 @@ while True:
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
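A back-of-envelope version of the MFU bookkeeping with illustrative numbers; num_flops_per_token and dt come from model.estimate_flops() and wall-clock timing in the real script, and the values here are made up for the arithmetic:

num_flops_per_token = 3.5e9    # hypothetical, roughly 6N for a ~560M-param model
total_batch_size = 524288
dt = 0.5                       # hypothetical seconds per optimization step
ddp_world_size = 8

tok_per_sec = int(total_batch_size / dt)                      # 1,048,576 tok/s
flops_per_sec = num_flops_per_token * total_batch_size / dt   # ~3.7e15 FLOP/s
promised_flops_per_sec_h100 = 989e12 * ddp_world_size         # bf16 H100 SXM peak, no 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100       # ~46%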
@@ -378,11 +380,11 @@ get_report().log(section="Base model training", data=[
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
        "Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": warmup_ratio,
        "warmdown_ratio": warmdown_ratio,
        "final_lr_frac": final_lr_frac,
        "warmup_ratio": args.warmup_ratio,
        "warmdown_ratio": args.warmdown_ratio,
        "final_lr_frac": args.final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
@@ -16,57 +16,69 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""

import argparse
import os
import itertools
import re
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K

# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
top_k = 50 # TODO: try None?
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)

# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts

# -----------------------------------------------------------------------------
@@ -74,7 +86,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts

train_task = GSM8K(subset="main", split="train")
val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")

@torch.no_grad()
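For scale: GSM8K's main train split has roughly 7.5k problems (the exact count comes from len(train_task) at runtime), so with the defaults the horizon is a few hundred steps:

len_train_task = 7473          # approximate GSM8K main/train size
examples_per_step = 16
num_epochs = 1
num_steps = (len_train_task // examples_per_step) * num_epochs  # 467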
@@ -95,16 +107,16 @@ def get_batch():
    model.eval() # ensure the model is in eval mode
    generated_token_sequences = []
    masks = []
    num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
    num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
    for sampling_step in range(num_sampling_steps):
        seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
        with autocast_ctx:
            generated_token_sequences_batch, masks_batch = engine.generate_batch(
                tokens,
                num_samples=device_batch_size,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                num_samples=args.device_batch_size,
                max_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                seed=seed, # must make sure to change the seed for each sampling step
            )
        generated_token_sequences.extend(generated_token_sequences_batch)
@@ -191,16 +203,16 @@ def run_gsm8k_eval(task, tokenizer, engine,

# Init the optimizer
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
    unembedding_lr=args.unembedding_lr,
    embedding_lr=args.embedding_lr,
    matrix_lr=args.matrix_lr,
    weight_decay=args.weight_decay,
)

# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["lr"] = group["lr"] * args.init_lr_frac
        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

# Learning rate scheduler: simple rampdown to zero over num_steps
@@ -209,9 +221,9 @@ def get_lr_multiplier(it):
    return lrm

# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")

# Kick off the training loop
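The batch geometry under the defaults, assuming an 8-rank launch:

examples_per_step = 16
num_samples = 16
device_batch_size = 8
ddp_world_size = 8             # assumed

total_sequences_per_step = examples_per_step * num_samples     # 256 rollouts per optimization step
examples_per_rank = examples_per_step // ddp_world_size        # 2 questions per GPU per step
num_sampling_steps = num_samples // device_batch_size          # 2 sequential generate_batch calls per question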
@@ -219,22 +231,22 @@ batch_iterator = get_batch()
for step in range(num_steps):

    # Evaluate the model once in a while and log to wandb
    if step % eval_every == 0:
    if step % args.eval_every == 0:
        model.eval()
        passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
        passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
        with autocast_ctx:
            records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
            records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
        records = list(records_iter) # collect all records
        for k in range(1, device_batch_size + 1):
        for k in range(1, args.device_batch_size + 1):
            passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
        num_records = torch.tensor(len(records), dtype=torch.long, device=device)
        if ddp:
            dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
            dist.all_reduce(passk, op=dist.ReduceOp.SUM)
        passk = passk / num_records.item() # normalize by the total number of records
        print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
        print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
        print0(f"Step {step} | {', '.join(print_passk)}")
        log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
        log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
        wandb_run.log({
            "step": step,
            **log_passk,
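A toy illustration of the pass@k bookkeeping above: for each problem, pass@k asks whether any of the first k sampled completions was correct, then averages over problems (records shaped like the eval's output):

records = [
    {"outcomes": [{"is_correct": False}, {"is_correct": True}]},   # solved on the 2nd sample
    {"outcomes": [{"is_correct": False}, {"is_correct": False}]},  # never solved
]
for k in (1, 2):
    passk = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) / len(records)
    print(f"pass@{k} = {passk:.2f}")  # pass@1 = 0.00, pass@2 = 0.50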
@@ -249,11 +261,11 @@ for step in range(num_steps):
    # Evaluate the loss and gradients
    model.train() # ensure the model is in train mode
    # We need one more loop because we can never exceed the device_batch_size
    assert inputs_all.size(0) % device_batch_size == 0
    num_passes = inputs_all.size(0) // device_batch_size
    assert inputs_all.size(0) % args.device_batch_size == 0
    num_passes = inputs_all.size(0) // args.device_batch_size
    for pass_idx in range(num_passes):
        # Pluck out the batch for this pass
        b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
        b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
        inputs = inputs_all[b0:b1]
        targets = targets_all[b0:b1]
        rewards = rewards_all[b0:b1]
@@ -306,10 +318,10 @@ for step in range(num_steps):
    })

    # Master process saves the model once in a while. Skip first step. Save last step.
    if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
    if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
        base_dir = get_base_dir()
        depth = model.config.n_layer
        output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
        model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
        save_checkpoint(
@@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""

import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

@@ -31,49 +32,51 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
# SFT Hyperparameters
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# input model options
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging there of
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)

# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
@@ -127,34 +130,36 @@ def sft_data_generator(dataset, batch_size):
        yield collate_and_yield(batch)
        batch = []

examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
examples_per_step = args.device_batch_size * ddp_world_size
print0(f"Target examples per step: {args.target_examples_per_step}")
print0(f"Device batch size: {args.device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = args.target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")

if num_iterations == -1:
if args.num_iterations == -1:
    # derive num_iterations from num_epochs and the size of the dataset
    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
    assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
    num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
else:
    num_iterations = args.num_iterations
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)

# -----------------------------------------------------------------------------
# Initialize the Optimizer

optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
    unembedding_lr=args.unembedding_lr,
    embedding_lr=args.embedding_lr,
    matrix_lr=args.matrix_lr,
    weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["lr"] = group["lr"] * args.init_lr_frac
        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

# -----------------------------------------------------------------------------
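The grad-accum and horizon arithmetic under the SFT defaults, assuming an 8-rank launch and a hypothetical 20k-example training set (the real len(train_ds) depends on the task mixture):

device_batch_size = 4
ddp_world_size = 8             # assumed
target_examples_per_step = 32
num_epochs = 1
len_train_ds = 20_000          # hypothetical

examples_per_step = device_batch_size * ddp_world_size                    # 32
grad_accum_steps = target_examples_per_step // examples_per_step          # 1
num_iterations = (len_train_ds // target_examples_per_step) * num_epochs  # 625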
@@ -171,11 +176,11 @@ for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # evaluate the validation loss
    if last_step or step % eval_every == 0:
    if last_step or step % args.eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        losses = []
        for _ in range(eval_steps):
        for _ in range(args.eval_steps):
            val_inputs, val_targets = next(val_loader)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
@@ -192,13 +197,13 @@ for step in range(num_iterations):
        model.train()

    # evaluate accuracy of the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
    if last_step or (step > 0 and step % args.eval_metrics_every == 0):
        model.eval()
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({
@@ -250,7 +255,7 @@ for step in range(num_iterations):
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
    output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
    model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
    save_checkpoint(
@@ -9,6 +9,7 @@ Or torchrun for training:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
@@ -31,65 +32,75 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
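Aside (an illustrative sketch, not part of this commit's diff): the same override that used to be written configurator-style now goes through argparse, and boolean toggles change shape, e.g. dry_run moves from --dry_run=1 to a plain --dry_run flag:

import argparse

# minimal stand-in for the parser defined above (only two of its flags)
parser = argparse.ArgumentParser(description="Midtrain the model")
parser.add_argument("--device_batch_size", type=int, default=32)
parser.add_argument("--dry_run", action="store_true")

# old style:  python -m scripts.mid_train --device_batch_size=16 --dry_run=1
# new style:  python -m scripts.mid_train --device_batch_size=16 --dry_run
args = parser.parse_args(["--device_batch_size=16", "--dry_run"])
print(vars(args))  # {'device_batch_size': 16, 'dry_run': True}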
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)

# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
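For reference, the batch-size bookkeeping above with the default flags and an 8-GPU torchrun launch (plain arithmetic, assuming ddp_world_size = 8):

# tokens_per_fwdbwd       = 32 * 2048          = 65,536 tokens per rank per micro-batch
# world_tokens_per_fwdbwd = 65,536 * 8         = 524,288 tokens per micro-batch across ranks
# grad_accum_steps        = 524,288 // 524,288 = 1 (a single micro-batch already fills the total batch)
# with --device_batch_size=16 the world total drops to 262,144 and grad_accum_steps becomes 2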
token_bytes = get_token_bytes(device=device)

# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later

# Midtraining data mixture and DataLoader
@@ -120,7 +131,7 @@ def mid_data_generator(split):
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
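A minimal sketch (assumed, not repository code) of why the scratch buffer is pinned: page-locked host memory is what lets the later .to(device, non_blocking=True) calls overlap the CPU-to-GPU copy with compute rather than blocking on it:

import torch

B, T = 32, 2048  # hypothetical sizes matching the defaults
cpu_buf = torch.empty(B * T + 1, dtype=torch.int64, pin_memory=torch.cuda.is_available())
if torch.cuda.is_available():
    inputs = cpu_buf[:-1].view(B, T).to("cuda", dtype=torch.int32, non_blocking=True)  # async copy from pinned memory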
@@ -139,18 +150,18 @@ def mid_data_generator(split):
last_step = True # toggle last_step to True, which will terminate the training loop
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < num_iterations <= it and split == "train":
if 0 < args.num_iterations <= it and split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
if args.num_iterations > 0:
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
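For intuition, a toy version (illustrative only) of how a buffer of B*T + 1 tokens becomes next-token prediction pairs: inputs are the first B*T tokens and targets are the same stream shifted left by one:

import torch

B, T = 2, 4                       # hypothetical tiny shapes
stream = torch.arange(B * T + 1)  # 9 tokens: 0..8
inputs = stream[:-1].view(B, T)   # [[0,1,2,3],[4,5,6,7]]
targets = stream[1:].view(B, T)   # [[1,2,3,4],[5,6,7,8]]
# position t in each row of inputs is trained to predict targets[row, t], i.e. the next token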
@@ -179,7 +190,7 @@ ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
flops_so_far = num_flops_per_token * args.total_batch_size * step

# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
@@ -188,10 +199,10 @@ while True:
last_step = bool(last_step_tensor.item())

# once in a while: evaluate the val bpb (all ranks participate)
if eval_every > 0 and (last_step or step % eval_every == 0):
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
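With the defaults on 8 ranks, the eval_steps expression above works out to (plain arithmetic, assuming ddp_world_size = 8):

# eval_tokens          = 20 * 524,288 = 10,485,760 tokens
# tokens per eval step = 32 * 2048 * 8 = 524,288
# eval_steps           = 10,485,760 // 524,288 = 20 forward passes per evaluation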
@@ -206,8 +217,8 @@ while True:
model.train()

# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
if master_process and last_step and not args.dry_run:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
@@ -218,7 +229,7 @@ while True:
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": max_seq_len,
"sequence_len": args.max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
@@ -268,8 +279,8 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
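Two of the logging formulas above, spelled out as a minimal sketch (not repository code). The EMA starts at 0, so dividing by 1 - beta**(step+1) removes the startup bias, and MFU is achieved bfloat16 FLOP/s over the advertised H100 peak of the job:

ema_beta = 0.9
smooth = 0.0
for step, loss in enumerate([4.0, 3.9, 3.8]):         # hypothetical per-step losses
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (step + 1))  # ~4.00, ~3.95, ~3.89: tracks the true level from step 0

# MFU with the defaults on 8 H100s: if a step of 524,288 tokens takes dt seconds,
# mfu = 100 * (num_flops_per_token * 524288 / dt) / (989e12 * 8)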
@@ -293,7 +304,7 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
if not dry_run:
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args

uv.lock | 404 (generated)
@@ -746,21 +746,22 @@ dependencies = [
|
||||
{ name = "setuptools" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
cpu = [
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
gpu = [
|
||||
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } },
|
||||
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
@@ -779,8 +780,8 @@ requires-dist = [
|
||||
{ name = "tiktoken", specifier = ">=0.11.0" },
|
||||
{ name = "tokenizers", specifier = ">=0.22.0" },
|
||||
{ name = "torch", specifier = ">=2.8.0" },
|
||||
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
|
||||
{ name = "torch", marker = "extra == 'gpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
|
||||
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
|
||||
{ name = "torch", marker = "extra == 'gpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
]
|
||||
@@ -909,7 +910,7 @@ name = "nvidia-cudnn-cu12"
|
||||
version = "9.10.2.21"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
|
||||
@@ -922,7 +923,7 @@ name = "nvidia-cufft-cu12"
|
||||
version = "11.3.3.83"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
|
||||
@@ -954,9 +955,9 @@ name = "nvidia-cusolver-cu12"
|
||||
version = "11.7.3.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
|
||||
@@ -969,7 +970,7 @@ name = "nvidia-cusparse-cu12"
|
||||
version = "12.5.8.93"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
|
||||
@@ -989,11 +990,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-nccl-cu12"
|
||||
version = "2.27.3"
|
||||
version = "2.27.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768, upload-time = "2025-06-03T21:57:30.234Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1006,6 +1007,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/d7/34f02dad2e30c31b10a51f6b04e025e5dd60e5f936af9045a9b858a05383/nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f", size = 268553710, upload-time = "2025-03-07T01:56:24.13Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-nvshmem-cu12"
|
||||
version = "3.3.20"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/92/9d/3dd98852568fb845ec1f7902c90a22b240fe1cbabda411ccedf2fd737b7b/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b0b960da3842212758e4fa4696b94f129090b30e5122fea3c5345916545cff0", size = 124484616, upload-time = "2025-08-04T20:24:59.172Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d00f26d3f9b2e3c3065be895e3059d6479ea5c638a3f38c9fec49b1b9dd7c1e5", size = 124657145, upload-time = "2025-08-04T20:25:19.995Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-nvtx-cu12"
|
||||
version = "12.8.90"
|
||||
@@ -1752,106 +1762,40 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.8.0+cu128"
|
||||
source = { registry = "https://download.pytorch.org/whl/cu128" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:43938e9a174c90e5eb9e906532b2f1e21532bbfa5a61b65193b4f54714d34f9e" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:039b9dcdd6bdbaa10a8a5cd6be22c4cb3e3589a341e5f904cbb571ca28f55bed" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:34c55443aafd31046a7963b63d30bc3b628ee4a704f826796c865fdfd05bb596" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:0ad925202387f4e7314302a1b4f8860fa824357f9b1466d7992bf276370ebcff" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3a852369a38dec343d45ecd0bc3660f79b88a23e0c878d18707f7c13bf49538f" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:9e20646802b7fc295c1f8b45fefcfc9fb2e4ec9cbe8593443cd2b9cc307c8405" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4295a22d69408e93d25f51e8d5d579345b6b802383e9414b0f3853ed433d53ae" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0"
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:59484193b01299bf669520505a72b29d59a0028ae4c6d95f492938f186592208" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aa4483602586cc9a35d1cf33771a9977f05f642b9161518a289e36548a0b77c2" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:4de0ed8cbc457a506dbca40376e206a29efee10756a00f1f3404bf67ad737d04" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:259548471194ab63d7ea273873053a6e3cc23530c1510f01e9d7ad259187bbd0" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e24836d968b54ef4dfb05594001a61958711ac9224026291e4e3f92f83a6fd7f" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d8e2ab7f86010330bdcc39c8b2c795590cc75e37df4823cdaee2c98d6e3ff4a3" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a3e859039c985d8e3ea60d7a54ca7e97ea2ae15e31beced4f3260128a161bb01" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "filelock", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "fsspec", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "jinja2", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvshmem-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "triton", version = "3.5.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" },
|
||||
@@ -1886,7 +1830,86 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0+cpu"
|
||||
version = "2.9.1"
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "fsspec", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "jinja2", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "sympy", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/56/9577683b23072075ed2e40d725c52c2019d71a972fab8e083763da8e707e/torch-2.9.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:1cc208435f6c379f9b8fdfd5ceb5be1e3b72a6bdf1cb46c0d2812aa73472db9e", size = 104207681, upload-time = "2025-11-12T15:19:56.48Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/45/be5a74f221df8f4b609b78ff79dc789b0cc9017624544ac4dd1c03973150/torch-2.9.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:9fd35c68b3679378c11f5eb73220fdcb4e6f4592295277fbb657d31fd053237c", size = 899794036, upload-time = "2025-11-12T15:21:01.886Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/95/a581e8a382596b69385a44bab2733f1273d45c842f5d4a504c0edc3133b6/torch-2.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:2af70e3be4a13becba4655d6cc07dcfec7ae844db6ac38d6c1dafeb245d17d65", size = 110969861, upload-time = "2025-11-12T15:21:30.145Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/51/1756dc128d2bf6ea4e0a915cb89ea5e730315ff33d60c1ff56fd626ba3eb/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a83b0e84cc375e3318a808d032510dde99d696a85fe9473fc8575612b63ae951", size = 74452222, upload-time = "2025-11-12T15:20:46.223Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/db/c064112ac0089af3d2f7a2b5bfbabf4aa407a78b74f87889e524b91c5402/torch-2.9.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:62b3fd888277946918cba4478cf849303da5359f0fb4e3bfb86b0533ba2eaf8d", size = 104220430, upload-time = "2025-11-12T15:20:31.705Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/be/76eaa36c9cd032d3b01b001e2c5a05943df75f26211f68fae79e62f87734/torch-2.9.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d033ff0ac3f5400df862a51bdde9bad83561f3739ea0046e68f5401ebfa67c1b", size = 899821446, upload-time = "2025-11-12T15:20:15.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/cc/7a2949e38dfe3244c4df21f0e1c27bce8aedd6c604a587dd44fc21017cb4/torch-2.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:0d06b30a9207b7c3516a9e0102114024755a07045f0c1d2f2a56b1819ac06bcb", size = 110973074, upload-time = "2025-11-12T15:21:39.958Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/ce/7d251155a783fb2c1bb6837b2b7023c622a2070a0a72726ca1df47e7ea34/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:52347912d868653e1528b47cafaf79b285b98be3f4f35d5955389b1b95224475", size = 74463887, upload-time = "2025-11-12T15:20:36.611Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/27/07c645c7673e73e53ded71705045d6cb5bae94c4b021b03aa8d03eee90ab/torch-2.9.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:da5f6f4d7f4940a173e5572791af238cb0b9e21b1aab592bd8b26da4c99f1cd6", size = 104126592, upload-time = "2025-11-12T15:20:41.62Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/17/e377a460603132b00760511299fceba4102bd95db1a0ee788da21298ccff/torch-2.9.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:27331cd902fb4322252657f3902adf1c4f6acad9dcad81d8df3ae14c7c4f07c4", size = 899742281, upload-time = "2025-11-12T15:22:17.602Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/1a/64f5769025db846a82567fa5b7d21dba4558a7234ee631712ee4771c436c/torch-2.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:81a285002d7b8cfd3fdf1b98aa8df138d41f1a8334fd9ea37511517cedf43083", size = 110940568, upload-time = "2025-11-12T15:21:18.689Z" },
{ url = "https://files.pythonhosted.org/packages/6e/ab/07739fd776618e5882661d04c43f5b5586323e2f6a2d7d84aac20d8f20bd/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:c0d25d1d8e531b8343bea0ed811d5d528958f1dcbd37e7245bc686273177ad7e", size = 74479191, upload-time = "2025-11-12T15:21:25.816Z" },
{ url = "https://files.pythonhosted.org/packages/20/60/8fc5e828d050bddfab469b3fe78e5ab9a7e53dda9c3bdc6a43d17ce99e63/torch-2.9.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c29455d2b910b98738131990394da3e50eea8291dfeb4b12de71ecf1fdeb21cb", size = 104135743, upload-time = "2025-11-12T15:21:34.936Z" },
{ url = "https://files.pythonhosted.org/packages/f2/b7/6d3f80e6918213babddb2a37b46dbb14c15b14c5f473e347869a51f40e1f/torch-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:524de44cd13931208ba2c4bde9ec7741fd4ae6bfd06409a604fc32f6520c2bc9", size = 899749493, upload-time = "2025-11-12T15:24:36.356Z" },
{ url = "https://files.pythonhosted.org/packages/a6/47/c7843d69d6de8938c1cbb1eba426b1d48ddf375f101473d3e31a5fc52b74/torch-2.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:545844cc16b3f91e08ce3b40e9c2d77012dd33a48d505aed34b7740ed627a1b2", size = 110944162, upload-time = "2025-11-12T15:21:53.151Z" },
{ url = "https://files.pythonhosted.org/packages/28/0e/2a37247957e72c12151b33a01e4df651d9d155dd74d8cfcbfad15a79b44a/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5be4bf7496f1e3ffb1dd44b672adb1ac3f081f204c5ca81eba6442f5f634df8e", size = 74830751, upload-time = "2025-11-12T15:21:43.792Z" },
{ url = "https://files.pythonhosted.org/packages/4b/f7/7a18745edcd7b9ca2381aa03353647bca8aace91683c4975f19ac233809d/torch-2.9.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:30a3e170a84894f3652434b56d59a64a2c11366b0ed5776fab33c2439396bf9a", size = 104142929, upload-time = "2025-11-12T15:21:48.319Z" },
{ url = "https://files.pythonhosted.org/packages/f4/dd/f1c0d879f2863ef209e18823a988dc7a1bf40470750e3ebe927efdb9407f/torch-2.9.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8301a7b431e51764629208d0edaa4f9e4c33e6df0f2f90b90e261d623df6a4e2", size = 899748978, upload-time = "2025-11-12T15:23:04.568Z" },
{ url = "https://files.pythonhosted.org/packages/1f/9f/6986b83a53b4d043e36f3f898b798ab51f7f20fdf1a9b01a2720f445043d/torch-2.9.1-cp313-cp313t-win_amd64.whl", hash = "sha256:2e1c42c0ae92bf803a4b2409fdfed85e30f9027a66887f5e7dcdbc014c7531db", size = 111176995, upload-time = "2025-11-12T15:22:01.618Z" },
{ url = "https://files.pythonhosted.org/packages/40/60/71c698b466dd01e65d0e9514b5405faae200c52a76901baf6906856f17e4/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:2c14b3da5df416cf9cb5efab83aa3056f5b8cd8620b8fde81b4987ecab730587", size = 74480347, upload-time = "2025-11-12T15:21:57.648Z" },
{ url = "https://files.pythonhosted.org/packages/48/50/c4b5112546d0d13cc9eaa1c732b823d676a9f49ae8b6f97772f795874a03/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1edee27a7c9897f4e0b7c14cfc2f3008c571921134522d5b9b5ec4ebbc69041a", size = 74433245, upload-time = "2025-11-12T15:22:39.027Z" },
{ url = "https://files.pythonhosted.org/packages/81/c9/2628f408f0518b3bae49c95f5af3728b6ab498c8624ab1e03a43dd53d650/torch-2.9.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:19d144d6b3e29921f1fc70503e9f2fc572cde6a5115c0c0de2f7ca8b1483e8b6", size = 104134804, upload-time = "2025-11-12T15:22:35.222Z" },
{ url = "https://files.pythonhosted.org/packages/28/fc/5bc91d6d831ae41bf6e9e6da6468f25330522e92347c9156eb3f1cb95956/torch-2.9.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:c432d04376f6d9767a9852ea0def7b47a7bbc8e7af3b16ac9cf9ce02b12851c9", size = 899747132, upload-time = "2025-11-12T15:23:36.068Z" },
{ url = "https://files.pythonhosted.org/packages/63/5d/e8d4e009e52b6b2cf1684bde2a6be157b96fb873732542fb2a9a99e85a83/torch-2.9.1-cp314-cp314-win_amd64.whl", hash = "sha256:d187566a2cdc726fc80138c3cdb260970fab1c27e99f85452721f7759bbd554d", size = 110934845, upload-time = "2025-11-12T15:22:48.367Z" },
{ url = "https://files.pythonhosted.org/packages/bd/b2/2d15a52516b2ea3f414643b8de68fa4cb220d3877ac8b1028c83dc8ca1c4/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cb10896a1f7fedaddbccc2017ce6ca9ecaaf990f0973bdfcf405439750118d2c", size = 74823558, upload-time = "2025-11-12T15:22:43.392Z" },
{ url = "https://files.pythonhosted.org/packages/86/5c/5b2e5d84f5b9850cd1e71af07524d8cbb74cba19379800f1f9f7c997fc70/torch-2.9.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:0a2bd769944991c74acf0c4ef23603b9c777fdf7637f115605a4b2d8023110c7", size = 104145788, upload-time = "2025-11-12T15:23:52.109Z" },
{ url = "https://files.pythonhosted.org/packages/a9/8c/3da60787bcf70add986c4ad485993026ac0ca74f2fc21410bc4eb1bb7695/torch-2.9.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:07c8a9660bc9414c39cac530ac83b1fb1b679d7155824144a40a54f4a47bfa73", size = 899735500, upload-time = "2025-11-12T15:24:08.788Z" },
{ url = "https://files.pythonhosted.org/packages/db/2b/f7818f6ec88758dfd21da46b6cd46af9d1b3433e53ddbb19ad1e0da17f9b/torch-2.9.1-cp314-cp314t-win_amd64.whl", hash = "sha256:c88d3299ddeb2b35dcc31753305612db485ab6f1823e37fb29451c8b2732b87e", size = 111163659, upload-time = "2025-11-12T15:23:20.009Z" },
]

[[package]]
name = "torch"
version = "2.9.1+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
@@ -1907,30 +1930,92 @@ dependencies = [
{ name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:b224792ea567b52c7f1ce1d789567f6920e06fd3b339fa1e1b05948845f783ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:bd2a257e670ede9fc01c6d76dccdc473040913b8e9328169bf177dbdc38e2484" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:96f3f7aa4eb9e7fc5af8a722eaf1e5e32e3039dbafe817178d7b90a8566be32d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:da77341ccaba31762d9238b0942c165c4582a26818f3045b052b39cebdd7ad9d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:add3e93ecc1eeaa6853f6a973ce60ffb3cb14ed2e80f5055e139b09385dce0a7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:389e1e0b8083fd355f7caf5ba82356b5e01c318998bd575dbf2285a0d8137089" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:5ce3d01aef91dc078fbb121814e556d55bc886d303efaf42c4fe67e411f5f9ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3a651434ae1248b0568c12b5f9e3acc8942eb28378d9d04a79302938b68c6f24" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:28f6eb31b08180a5c5e98d5bc14eef6909c9f5a1dbff9632c3e02a8773449349" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:e438061b87ec7dd6018fca9f975219889aa0a3f6cdc3ea10dd0ae2bc7f1c47ce" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:eb13ff1c34e338d722e76a4fd83b8d282782505bd1b99af4b3c32da66eba6eb4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:be4438d8dad7f0d5a5e54f0feef8a893446894ec87f102bb1d82dcc4518542e4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c9b217584400963d5b4daddb3711ec7a3778eab211e18654fba076cce3b8682" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:728372e3f58c5826445f677746e5311c1935c1a7c59599f73a49ded850e038e8" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:95e56c26f919fbb98f16e7a0b87af494b893f9da9a65a020f17a01c13e520a81" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:6c777160288b08555820781ae0f3a2c67a59bd24b065e88ca1ec20e2f9dc8ac7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:528fd338311f31c9fb18038cafd00e6eae0bf5ad5577521701acb62510753d18" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:d572863990e7d2762b547735ef589f6350d9eb4e441d38753a1c33636698cf4c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:44aadb735774d4a99525d2ec29126b23016c44a07b02ce6c237dfa61a223dd52" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b355e07b7f0c369cb031adfcbff5c37a609abcea091b918a39886412afd2e07d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:c2698999361d73c2d25d7cc8a787130188d49b183abb18b554228daa102e1594" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:fa0d1373d04b30ff8f12d542135d292f1a1ddb7c0d852a3d487a320360e5dab9" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:2f49bb57a5fe0dc7f8e73ea9e5d36ebda2ea25b8a714a788f0fc2fc47d20a830" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:3a60d1ecf27a9cce839b3aa665b26f0af1b1007b9c9f1e7f597f6b7bdf107617" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl" },
]

[[package]]
name = "torch"
version = "2.9.1+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-cupti-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-runtime-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cudnn-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufft-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufile-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-curand-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusolver-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparselt-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nccl-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvshmem-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvtx-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "triton", version = "3.5.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl" },
]

[[package]]
@@ -1947,17 +2032,54 @@ wheels = [

[[package]]
name = "triton"
version = "3.4.0"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },
{ url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" },
{ url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" },
{ url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" },
{ url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" },
{ url = "https://files.pythonhosted.org/packages/dd/22/507b6f58a35e05e84381630b2dc2a3cee1a7a2a7eaf4cba857c638a18a24/triton-3.5.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6f90de6a6566bb619b4c0adc9855729e1b1b5e26533fca1bf6206e96b6d277a3", size = 159827599, upload-time = "2025-10-15T19:15:43.87Z" },
{ url = "https://files.pythonhosted.org/packages/0b/eb/09e31d107a5d00eb281aa7e6635ca463e9bca86515944e399480eadb71f8/triton-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5d3b3d480debf24eaa739623c9a42446b0b77f95593d30eb1f64cd2278cc1f0", size = 170333110, upload-time = "2025-10-13T16:37:49.588Z" },
{ url = "https://files.pythonhosted.org/packages/79/f9/b6f60f978397c616fd8dacca2305759fe4f80d397b20ef72534803244bd5/triton-3.5.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8457b22148defefdcb7fa8144b05ce211b9faefad650a1ce85b23df488d5549c", size = 159926731, upload-time = "2025-10-15T19:15:49.682Z" },
{ url = "https://files.pythonhosted.org/packages/3d/78/949a04391c21956c816523678f0e5fa308eb5b1e7622d88c4e4ef5fceca0/triton-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f34bfa21c5b3a203c0f0eab28dcc1e49bd1f67d22724e77fb6665a659200a4ec", size = 170433488, upload-time = "2025-10-13T16:37:57.132Z" },
{ url = "https://files.pythonhosted.org/packages/87/9b/30988039e1e84df7554fba24e6a734d2d0e847af33cabdf9b532b3c51456/triton-3.5.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da21fccceafc163e3a5e857abe34351ef76345af06cabf9637a914742671f0b", size = 159946647, upload-time = "2025-10-15T19:15:56.325Z" },
{ url = "https://files.pythonhosted.org/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833", size = 170476535, upload-time = "2025-10-13T16:38:05.18Z" },
{ url = "https://files.pythonhosted.org/packages/cd/85/e37f1197acb04c8f3d83851d23d5d6ed5060ef74580668b112e23fdfa203/triton-3.5.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:188da5b81fa2f8322c27fec1627703eac24cb9bb7ab0dfbe9925973bc1b070d3", size = 159958970, upload-time = "2025-10-15T19:16:01.717Z" },
{ url = "https://files.pythonhosted.org/packages/6c/29/10728de8a6e932e517c10773486b8e99f85d1b1d9dd87d9a9616e1fef4a1/triton-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6bb9aa5519c084a333acdba443789e50012a4b851cd486c54f0b8dc2a8d3a12", size = 170487289, upload-time = "2025-10-13T16:38:11.662Z" },
{ url = "https://files.pythonhosted.org/packages/b8/1d/38258f05010ac17a7b058c022911c9cae6526e149b7397134a048cf5a6c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03127d9b33aaf979c856676b394bc059ec1d68cb6da68ae03f62dd8ad77a04ae", size = 160073012, upload-time = "2025-10-15T19:16:07.477Z" },
{ url = "https://files.pythonhosted.org/packages/5c/38/db80e48b9220c9bce872b0f616ad0446cdf554a40b85c7865cbca99ab3c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c83f2343e1a220a716c7b3ab9fccfcbe3ad4020d189549200e2d2e8d5868bed9", size = 170577179, upload-time = "2025-10-13T16:38:17.865Z" },
{ url = "https://files.pythonhosted.org/packages/91/fe/8f5771d00227f4eb1ee034f218ed427102b989366d2275fe3b3c105a3921/triton-3.5.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:468936651d383f4a6d10068d34a627505e13af55be5d002b9f27b987e7a5f0ac", size = 159957460, upload-time = "2025-10-15T19:16:12.626Z" },
{ url = "https://files.pythonhosted.org/packages/ff/60/1810655d1d856c9a4fcc90ee8966d85f552d98c53a6589f95ab2cbe27bb8/triton-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da0fa67ccd76c3dcfb0bffe1b1c57c685136a6bd33d141c24d9655d4185b1289", size = 170487949, upload-time = "2025-10-13T16:38:24.881Z" },
{ url = "https://files.pythonhosted.org/packages/78/59/99edd103958fe6e42b50b9ad8ce4f223ddf4ccf475259cf7d2b53381dc6c/triton-3.5.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7ceef21410229ac23173a28eee5cfc0e37c1dfdb8b4bc11ecda2e3ecec7c686", size = 160075629, upload-time = "2025-10-15T19:16:18.746Z" },
{ url = "https://files.pythonhosted.org/packages/fb/b7/1dec8433ac604c061173d0589d99217fe7bf90a70bdc375e745d044b8aad/triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7", size = 170580176, upload-time = "2025-10-13T16:38:31.14Z" },
]

[[package]]
name = "triton"
version = "3.5.1"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/2e/f95e673222afa2c7f0c687d8913e98fcf2589ef0b1405de76894e37fe18f/triton-3.5.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f63e34dcb32d7bd3a1d0195f60f30d2aee8b08a69a0424189b71017e23dfc3d2", size = 159821655, upload-time = "2025-11-11T17:51:44.09Z" },
{ url = "https://files.pythonhosted.org/packages/fd/6e/676ab5019b4dde8b9b7bab71245102fc02778ef3df48218b298686b9ffd6/triton-3.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fc53d849f879911ea13f4a877243afc513187bc7ee92d1f2c0f1ba3169e3c94", size = 170320692, upload-time = "2025-11-11T17:40:46.074Z" },
{ url = "https://files.pythonhosted.org/packages/dc/dc/6ce44d055f2fc2403c4ec6b3cfd3a9b25f57b7d95efadccdea91497f8e81/triton-3.5.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da47169e30a779bade679ce78df4810fca6d78a955843d2ddb11f226adc517dc", size = 159928005, upload-time = "2025-11-11T17:51:50.008Z" },
{ url = "https://files.pythonhosted.org/packages/b0/72/ec90c3519eaf168f22cb1757ad412f3a2add4782ad3a92861c9ad135d886/triton-3.5.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:61413522a48add32302353fdbaaf92daaaab06f6b5e3229940d21b5207f47579", size = 170425802, upload-time = "2025-11-11T17:40:53.209Z" },
{ url = "https://files.pythonhosted.org/packages/db/53/2bcc46879910991f09c063eea07627baef2bc62fe725302ba8f46a2c1ae5/triton-3.5.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:275a045b6ed670dd1bd005c3e6c2d61846c74c66f4512d6f33cc027b11de8fd4", size = 159940689, upload-time = "2025-11-11T17:51:55.938Z" },
{ url = "https://files.pythonhosted.org/packages/f2/50/9a8358d3ef58162c0a415d173cfb45b67de60176e1024f71fbc4d24c0b6d/triton-3.5.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d2c6b915a03888ab931a9fd3e55ba36785e1fe70cbea0b40c6ef93b20fc85232", size = 170470207, upload-time = "2025-11-11T17:41:00.253Z" },
{ url = "https://files.pythonhosted.org/packages/f1/ba/805684a992ee32d486b7948d36aed2f5e3c643fc63883bf8bdca1c3f3980/triton-3.5.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56765ffe12c554cd560698398b8a268db1f616c120007bfd8829d27139abd24a", size = 159955460, upload-time = "2025-11-11T17:52:01.861Z" },
{ url = "https://files.pythonhosted.org/packages/27/46/8c3bbb5b0a19313f50edcaa363b599e5a1a5ac9683ead82b9b80fe497c8d/triton-3.5.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3f4346b6ebbd4fad18773f5ba839114f4826037c9f2f34e0148894cd5dd3dba", size = 170470410, upload-time = "2025-11-11T17:41:06.319Z" },
{ url = "https://files.pythonhosted.org/packages/84/1e/7df59baef41931e21159371c481c31a517ff4c2517343b62503d0cd2be99/triton-3.5.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02c770856f5e407d24d28ddc66e33cf026e6f4d360dcb8b2fabe6ea1fc758621", size = 160072799, upload-time = "2025-11-11T17:52:07.293Z" },
{ url = "https://files.pythonhosted.org/packages/37/92/e97fcc6b2c27cdb87ce5ee063d77f8f26f19f06916aa680464c8104ef0f6/triton-3.5.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0b4d2c70127fca6a23e247f9348b8adde979d2e7a20391bfbabaac6aebc7e6a8", size = 170579924, upload-time = "2025-11-11T17:41:12.455Z" },
{ url = "https://files.pythonhosted.org/packages/14/f9/0430e879c1e63a1016cb843261528fd3187c872c3a9539132efc39514753/triton-3.5.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f617aa7925f9ea9968ec2e1adaf93e87864ff51549c8f04ce658f29bbdb71e2d", size = 159956163, upload-time = "2025-11-11T17:52:12.999Z" },
{ url = "https://files.pythonhosted.org/packages/a4/e6/c595c35e5c50c4bc56a7bac96493dad321e9e29b953b526bbbe20f9911d0/triton-3.5.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0637b1efb1db599a8e9dc960d53ab6e4637db7d4ab6630a0974705d77b14b60", size = 170480488, upload-time = "2025-11-11T17:41:18.222Z" },
{ url = "https://files.pythonhosted.org/packages/41/1e/63d367c576c75919e268e4fbc33c1cb33b6dc12bb85e8bfe531c2a8bd5d3/triton-3.5.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8932391d7f93698dfe5bc9bead77c47a24f97329e9f20c10786bb230a9083f56", size = 160073620, upload-time = "2025-11-11T17:52:18.403Z" },
{ url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" },
]

[[package]]