From 3a5e0bc50b123f37220f1ace6813f080455d19f3 Mon Sep 17 00:00:00 2001 From: karpathy Date: Mon, 13 Oct 2025 06:49:24 -0700 Subject: [PATCH] initial commit --- .gitignore | 5 + .python-version | 1 + README.md | 132 ++ dev/generate_logo.html | 29 + dev/repackage_data_reference.py | 92 ++ nanochat/__init__.py | 0 nanochat/adamw.py | 77 ++ nanochat/checkpoint_manager.py | 146 +++ nanochat/common.py | 136 +++ nanochat/configurator.py | 56 + nanochat/core_eval.py | 262 ++++ nanochat/dataloader.py | 49 + nanochat/dataset.py | 128 ++ nanochat/engine.py | 343 ++++++ nanochat/execution.py | 350 ++++++ nanochat/gpt.py | 322 +++++ nanochat/logo.svg | 8 + nanochat/loss_eval.py | 63 + nanochat/muon.py | 187 +++ nanochat/report.py | 404 +++++++ nanochat/tokenizer.py | 395 ++++++ nanochat/ui.html | 394 ++++++ pyproject.toml | 55 + rustbpe/Cargo.lock | 458 +++++++ rustbpe/Cargo.toml | 15 + rustbpe/README.md | 5 + rustbpe/src/lib.rs | 476 ++++++++ scripts/base_eval.py | 180 +++ scripts/base_loss.py | 78 ++ scripts/base_train.py | 339 ++++++ scripts/chat_cli.py | 99 ++ scripts/chat_eval.py | 251 ++++ scripts/chat_rl.py | 331 +++++ scripts/chat_sft.py | 281 +++++ scripts/chat_web.py | 198 +++ scripts/mid_train.py | 289 +++++ scripts/tok_eval.py | 265 ++++ scripts/tok_train.py | 106 ++ speedrun.sh | 133 ++ tasks/arc.py | 48 + tasks/common.py | 147 +++ tasks/gsm8k.py | 117 ++ tasks/humaneval.py | 97 ++ tasks/mmlu.py | 60 + tasks/smoltalk.py | 46 + tests/test_rustbpe.py | 635 ++++++++++ uv.lock | 2004 +++++++++++++++++++++++++++++++ 47 files changed, 10292 insertions(+) create mode 100644 .gitignore create mode 100644 .python-version create mode 100644 README.md create mode 100644 dev/generate_logo.html create mode 100644 dev/repackage_data_reference.py create mode 100644 nanochat/__init__.py create mode 100644 nanochat/adamw.py create mode 100644 nanochat/checkpoint_manager.py create mode 100644 nanochat/common.py create mode 100644 nanochat/configurator.py create mode 100644 nanochat/core_eval.py create mode 100644 nanochat/dataloader.py create mode 100644 nanochat/dataset.py create mode 100644 nanochat/engine.py create mode 100644 nanochat/execution.py create mode 100644 nanochat/gpt.py create mode 100644 nanochat/logo.svg create mode 100644 nanochat/loss_eval.py create mode 100644 nanochat/muon.py create mode 100644 nanochat/report.py create mode 100644 nanochat/tokenizer.py create mode 100644 nanochat/ui.html create mode 100644 pyproject.toml create mode 100644 rustbpe/Cargo.lock create mode 100644 rustbpe/Cargo.toml create mode 100644 rustbpe/README.md create mode 100644 rustbpe/src/lib.rs create mode 100644 scripts/base_eval.py create mode 100644 scripts/base_loss.py create mode 100644 scripts/base_train.py create mode 100644 scripts/chat_cli.py create mode 100644 scripts/chat_eval.py create mode 100644 scripts/chat_rl.py create mode 100644 scripts/chat_sft.py create mode 100644 scripts/chat_web.py create mode 100644 scripts/mid_train.py create mode 100644 scripts/tok_eval.py create mode 100644 scripts/tok_train.py create mode 100644 speedrun.sh create mode 100644 tasks/arc.py create mode 100644 tasks/common.py create mode 100644 tasks/gsm8k.py create mode 100644 tasks/humaneval.py create mode 100644 tasks/mmlu.py create mode 100644 tasks/smoltalk.py create mode 100644 tests/test_rustbpe.py create mode 100644 uv.lock diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b14ecde --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.venv/ +__pycache__/ +*.pyc +rustbpe/target/ +dev-ignore/ 
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..c8cfe39
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..df997ed
--- /dev/null
+++ b/README.md
@@ -0,0 +1,132 @@
+# nanochat
+
+> The best ChatGPT that $100 can buy.
+
+This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh) that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
+
+## Quick start
+
+The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
+
+```bash
+bash speedrun.sh
+```
+
+Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
+
+```bash
+screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
+```
+
+See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
+
+```bash
+python -m scripts.chat_web
+```
+
+And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
+
+You can also `cat` the `report.md` file that appeared in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics. At the very end, you'll see a summary table, for example:
+
+---
+
+- Characters: 333,989
+- Lines: 8,304
+- Files: 44
+- Tokens (approx): 83,497
+- Dependencies (uv.lock lines): 2,004
+
+| Metric          | BASE     | MID      | SFT      | RL       |
+|-----------------|----------|----------|----------|----------|
+| CORE            | 0.2219   | -        | -        | -        |
+| ARC-Challenge   | -        | 0.2875   | 0.2807   | -        |
+| ARC-Easy        | -        | 0.3561   | 0.3876   | -        |
+| GSM8K           | -        | 0.0250   | 0.0455   | 0.0758   |
+| HumanEval       | -        | 0.0671   | 0.0854   | -        |
+| MMLU            | -        | 0.3111   | 0.3151   | -        |
+| ChatCORE        | -        | 0.0730   | 0.0884   | -        |
+
+Total wall clock time: 3h51m
+
+---
+
+(Your table might be missing the RL number by default).
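+
+As a rough sanity check of the ~4e19 FLOPs figure mentioned above, here is a back-of-the-envelope estimate (a sketch: the ~560M parameter count for the default speedrun model is an assumption here, not something this README specifies; the 20 tokens/param rule of thumb appears in the "Bigger models" section below):
+
+```python
+n_params = 560e6          # assumed parameter count of the default speedrun model
+n_tokens = 20 * n_params  # ~20 tokens per parameter (the rule of thumb used below)
+flops = 6 * n_params * n_tokens  # standard 6*N*D estimate of training FLOPs
+print(f"{flops:.1e}")     # -> 3.8e+19, i.e. the ~4e19 quoted above
+```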
+
+For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["nanochat: the best ChatGPT that $100 can buy"]().
+
+## Bigger models
+
+Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms the GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported, and therefore not attached here in the master branch.
+
+That said, to give a sense, training a GPT-2 grade d26 model only requires three changes to the [speedrun.sh](speedrun.sh) file:
+
+```bash
+...
+# you'll need to download more data shards for pretraining
+# get the number of parameters, multiply by 20 to get tokens, multiply by 4.8 to get chars,
+# divide by 250 million to get number of shards. todo need to improve this...
+python -m nanochat.dataset -n 450 &
+...
+# use --depth to increase model size. to avoid OOM, halve device batch size 32 -> 16:
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
+...
+# make sure to use the same later during midtraining:
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+```
+
+That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute into sequential compute; see the sketch below).
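+
+Roughly, the compensation works like this (a minimal sketch with made-up numbers; the actual scripts derive these values from their own hyperparameters):
+
+```python
+total_batch_size = 524_288  # assumption: target tokens per optimizer step
+seq_len = 2_048             # assumption: training sequence length
+world_size = 8              # number of GPUs
+for device_batch_size in [32, 16, 8]:
+    tokens_per_fwdbwd = device_batch_size * seq_len * world_size
+    grad_accum_steps = total_batch_size // tokens_per_fwdbwd
+    print(device_batch_size, grad_accum_steps)  # -> 32:1, 16:2, 8:4
+```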
+
+And a bit more about computing environments that will run nanochat:
+
+- The code will run just fine on an Ampere 8XA100 GPU node as well, just a bit slower.
+- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
+- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit, e.g. from 32 (the default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative.
+- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that (xpu, mps, etc.), but I haven't implemented this out of the box so it might take a bit of tinkering.
+
+## Questions
+
+nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy-paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
+
+```bash
+files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
+```
+
+This includes all py, md, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below the ~100K token context of a state of the art LLM), and ~8K lines of code in 45 files.
+
+Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
+
+## Tests
+
+I haven't invested too much here but some tests exist, especially for the tokenizer. Run them e.g. as:
+
+```bash
+python -m pytest tests/test_rustbpe.py -v -s
+```
+
+## Contributing
+
+nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
+
+## Acknowledgements
+
+- The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
+- nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
+- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
+- Thank you to [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
+- Thank you to chief LLM whisperer πŸ§™β€β™‚οΈ Alec Radford for advice/guidance.
+
+## Cite
+
+If you find nanochat helpful in your research, cite it simply as:
+
+```bibtex
+@misc{nanochat,
+  author = {Andrej Karpathy},
+  title = {nanochat: The best ChatGPT that $100 can buy},
+  year = {2025},
+  publisher = {GitHub},
+  url = {https://github.com/karpathy/nanochat}
+}
+```
+
+## License
+
+MIT
diff --git a/dev/generate_logo.html b/dev/generate_logo.html
new file mode 100644
index 0000000..edfa99e
--- /dev/null
+++ b/dev/generate_logo.html
@@ -0,0 +1,29 @@ + + + + + + + + + + + + + + \ No newline at end of file
diff --git a/dev/repackage_data_reference.py b/dev/repackage_data_reference.py
new file mode 100644
index 0000000..32980a8
--- /dev/null
+++ b/dev/repackage_data_reference.py
@@ -0,0 +1,92 @@
+"""
+Repackage the FinewebEdu-100B dataset into shards:
+
+- each shard is ~100MB in size (after zstd compression)
+- parquets are written with a row group size of 1024
+- shuffle the dataset
+
+This will be uploaded to HuggingFace for hosting.
+The big deal is that our DataLoader will be able to stream
+the data and cache it along the way on disk, decreasing the
+training latency.
+
+NOTE: This file is meant only as reference/documentation of the
+dataset preparation and it is not used during the project runtime.
+"""
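+# Rough shard math implied by the constants in this script (illustrative; assumes ~1 byte/char
+# and a ~2.5x zstd compression ratio for English web text):
+#   250_000_000 chars/shard / 2.5 -> ~100MB per shard on disk, matching the docstring above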
+""" +import os +import time + +from datasets import load_dataset +import pyarrow.parquet as pq +import pyarrow as pa + +# Source dataset +dataset_kwargs = { + "path": "HuggingFaceFW/fineweb-edu", + "split": "train", + "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total +} +ds = load_dataset(**dataset_kwargs) + +# Shuffle to scramble the order +ds = ds.shuffle(seed=42) +ndocs = len(ds) # total number of documents to process +print(f"Total number of documents: {ndocs}") + +# Repackage into parquet files +output_dir = "/home/ubuntu/.cache/nanochat/base_data" +os.makedirs(output_dir, exist_ok=True) + +# Write to parquet files +chars_per_shard = 250_000_000 +row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later +shard_docs = [] +shard_index = 0 +shard_characters = 0 +total_docs_processed = 0 +total_time_spent = 0 +t0 = time.time() +for doc in ds: + text = doc['text'] + shard_docs.append(text) + shard_characters += len(text) + collected_enough_chars = shard_characters >= chars_per_shard + docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0 + if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed) + shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet") + shard_table = pa.Table.from_pydict({"text": shard_docs}) + pq.write_table( + shard_table, + shard_path, + row_group_size=row_group_size, + use_dictionary=False, # this is usually used for categorical data + compression="zstd", # Valid values: {β€˜NONE’, β€˜SNAPPY’, β€˜GZIP’, β€˜BROTLI’, β€˜LZ4’, β€˜ZSTD’} + compression_level=3, + write_statistics=False, # not needed for text + ) + t1 = time.time() + dt = t1 - t0 # for this shard alone + t0 = t1 + total_docs_processed += len(shard_docs) + total_time_spent += dt + remaining_docs = ndocs - total_docs_processed + avg_time_per_doc = total_time_spent / total_docs_processed + remaining_time = remaining_docs * avg_time_per_doc + remaining_time_hours = remaining_time / 3600 + print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h") + shard_docs = [] + shard_characters = 0 + shard_index += 1 + +# Demonstration of how the data was later uploaded to HuggingFace +def upload(): + import os + from huggingface_hub import HfApi + token = os.getenv("HF_TOKEN") + api = HfApi(token=token) + api.upload_large_folder( + folder_path=output_dir, + repo_id="karpathy/fineweb-edu-100b-shuffle", + repo_type="dataset", + ) +# upload() diff --git a/nanochat/__init__.py b/nanochat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nanochat/adamw.py b/nanochat/adamw.py new file mode 100644 index 0000000..07b82de --- /dev/null +++ b/nanochat/adamw.py @@ -0,0 +1,77 @@ +""" +Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. +Not a general optimizer! But works for our specific use. +""" +import torch +import torch.distributed as dist +from torch import Tensor + + +class DistAdamW(torch.optim.Optimizer): + """ + Distributed AdamW optimizer. + In the style of ZeRO-2, i.e. 
sharded optimizer states and gradient reduction + """ + def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(param_groups, defaults) + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_reduce_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_reduce_futures).wait() diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py new file mode 100644 index 0000000..961ebd5 --- /dev/null +++ b/nanochat/checkpoint_manager.py @@ -0,0 +1,146 @@ +""" +Utilities for saving and loading model/optim/state checkpoints. 
+""" +import os +import re +import glob +import json +import logging +import torch + +from nanochat.common import get_base_dir +from nanochat.gpt import GPT, GPTConfig +from nanochat.tokenizer import get_tokenizer +from nanochat.common import setup_default_logging + +# Set up logging +setup_default_logging() +logger = logging.getLogger(__name__) +def log0(message): + if int(os.environ.get('RANK', 0)) == 0: + logger.info(message) + +def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): + assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now + os.makedirs(checkpoint_dir, exist_ok=True) + # Save the model state (parameters) + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") + torch.save(model_data, model_path) + log0(f"Saved model file to: {model_path}") + # Save the optimizer state (useful for SFT or any other fine-tuning) + if optimizer_data is not None: + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + torch.save(optimizer_data, optimizer_path) + log0(f"Saved optimizer file to: {optimizer_path}") + # Save the metadata dict as json + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + with open(meta_path, "w") as f: + json.dump(meta_data, f, indent=2) + log0(f"Saved metadata file to: {meta_path}") + + +def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): + # Load the model state + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") + model_data = torch.load(model_path, map_location=device) + # Load the optimizer state if requested + optimizer_data = None + if load_optimizer: + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_data = torch.load(optimizer_path, map_location=device) + # Load the metadata + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + with open(meta_path, "r") as f: + meta_data = json.load(f) + return model_data, optimizer_data, meta_data + + +def build_model(checkpoint_dir, step, device, phase): + """ + A bunch of repetitive code to build a model from a given checkpoint. + Returns: + - base model - uncompiled, not wrapped in DDP + - tokenizer + - meta data saved during base model training + """ + assert phase in ["train", "eval"], f"Invalid phase: {phase}" + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + # Hack: fix torch compile issue, which prepends all keys with _orig_mod. + model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} + model_config_kwargs = meta_data["model_config"] + log0(f"Building model with config: {model_config_kwargs}") + model_config = GPTConfig(**model_config_kwargs) + with torch.device("meta"): + model = GPT(model_config) + # Load the model state + model.to_empty(device=device) + model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. 
+
+
+def build_model(checkpoint_dir, step, device, phase):
+    """
+    A bunch of repetitive code to build a model from a given checkpoint.
+    Returns:
+    - base model - uncompiled, not wrapped in DDP
+    - tokenizer
+    - meta data saved during base model training
+    """
+    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
+    # (removeprefix, not lstrip: lstrip would strip any leading run of these characters)
+    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
+    model_config_kwargs = meta_data["model_config"]
+    log0(f"Building model with config: {model_config_kwargs}")
+    model_config = GPTConfig(**model_config_kwargs)
+    with torch.device("meta"):
+        model = GPT(model_config)
+    # Load the model state
+    model.to_empty(device=device)
+    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.load_state_dict(model_data, strict=True, assign=True)
+    # Put the model in the right training phase / mode
+    if phase == "eval":
+        model.eval()
+    else:
+        model.train()
+    # Load the Tokenizer
+    tokenizer = get_tokenizer()
+    # Sanity check: compatibility between model and tokenizer
+    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
+    return model, tokenizer, meta_data
+
+
+def find_largest_model(checkpoint_dir):
+    # attempt to guess the model tag: take the biggest model available
+    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
+    if not model_tags:
+        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+    # 1) normally all model tags are of the form d<depth>, try that first:
+    candidates = []
+    for model_tag in model_tags:
+        match = re.match(r"d(\d+)", model_tag)
+        if match:
+            model_depth = int(match.group(1))
+            candidates.append((model_depth, model_tag))
+    if candidates:
+        candidates.sort(key=lambda x: x[0], reverse=True)
+        return candidates[0][1]
+    # 2) if that failed, take the most recently updated model:
+    model_tags.sort(key=lambda t: os.path.getmtime(os.path.join(checkpoint_dir, t)), reverse=True)
+    return model_tags[0]
+
+
+def find_last_step(checkpoint_dir):
+    # Look into checkpoint_dir and find model_<step>.pt with the highest step
+    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
+    if not checkpoint_files:
+        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+    last_step = max(int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in checkpoint_files)
+    return last_step
+
+# -----------------------------------------------------------------------------
+# convenience functions that take into account nanochat's directory structure
+
+def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+    if model_tag is None:
+        # guess the model tag by defaulting to the largest model
+        model_tag = find_largest_model(checkpoints_dir)
+        log0(f"No model tag provided, guessing model tag: {model_tag}")
+    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+    if step is None:
+        # guess the step by defaulting to the last step
+        step = find_last_step(checkpoint_dir)
+    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
+    # build the model
+    log0(f"Loading model from {checkpoint_dir} with step {step}")
+    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
+    return model, tokenizer, meta_data
+
+def load_model(source, *args, **kwargs):
+    model_dir = {
+        "base": "base_checkpoints",
+        "mid": "mid_checkpoints",
+        "sft": "chatsft_checkpoints",
+        "rl": "chatrl_checkpoints",
+    }[source]
+    base_dir = get_base_dir()
+    checkpoints_dir = os.path.join(base_dir, model_dir)
+    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
diff --git a/nanochat/common.py b/nanochat/common.py
new file mode 100644
index 0000000..8b10df9
--- /dev/null
+++ b/nanochat/common.py
@@ -0,0 +1,136 @@
+"""
+Common utilities for nanochat.
+""" + +import os +import re +import logging +import torch +import torch.distributed as dist + +class ColoredFormatter(logging.Formatter): + """Custom formatter that adds colors to log messages.""" + # ANSI color codes + COLORS = { + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green + 'WARNING': '\033[33m', # Yellow + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta + } + RESET = '\033[0m' + BOLD = '\033[1m' + def format(self, record): + # Add color to the level name + levelname = record.levelname + if levelname in self.COLORS: + record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" + # Format the message + message = super().format(record) + # Add color to specific parts of the message + if levelname == 'INFO': + # Highlight numbers and percentages + message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) + message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) + return message + +def setup_default_logging(): + handler = logging.StreamHandler() + handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + logging.basicConfig( + level=logging.INFO, + handlers=[handler] + ) + +setup_default_logging() +logger = logging.getLogger(__name__) + +def get_base_dir(): + # co-locate nanochat intermediates with other cached data in ~/.cache (by default) + if os.environ.get("NANOCHAT_BASE_DIR"): + nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR") + else: + home_dir = os.path.expanduser("~") + cache_dir = os.path.join(home_dir, ".cache") + nanochat_dir = os.path.join(cache_dir, "nanochat") + os.makedirs(nanochat_dir, exist_ok=True) + return nanochat_dir + +def print0(s="",**kwargs): + ddp_rank = int(os.environ.get('RANK', 0)) + if ddp_rank == 0: + print(s, **kwargs) + +def print_banner(): + # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ + banner = """ + β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ + β–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆ + β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ +β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘β–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘β–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘ + β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ + β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆ + β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ +β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘ +""" + print0(banner) + +def is_ddp(): + # TODO is there a proper way + return int(os.environ.get('RANK', -1)) != -1 + +def get_dist_info(): + if is_ddp(): + assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) + ddp_rank = int(os.environ['RANK']) + ddp_local_rank = int(os.environ['LOCAL_RANK']) + 
ddp_world_size = int(os.environ['WORLD_SIZE'])
+        return True, ddp_rank, ddp_local_rank, ddp_world_size
+    else:
+        return False, 0, 0, 1
+
+def compute_init():
+    """Basic initialization that we keep doing over and over, so make common."""
+
+    # CUDA is currently required
+    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+
+    # Reproducibility
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+    # skipping full reproducibility for now, possibly investigate slowdown later
+    # torch.use_deterministic_algorithms(True)
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+
+    # Precision
+    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+
+    # Distributed setup: Distributed Data Parallel (DDP), optional
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if ddp:
+        device = torch.device("cuda", ddp_local_rank)
+        torch.cuda.set_device(device) # make "cuda" default to this device
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    else:
+        device = torch.device("cuda")
+
+    if ddp_rank == 0:
+        logger.info(f"Distributed world size: {ddp_world_size}")
+
+    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+
+def compute_cleanup():
+    """Companion function to compute_init, to clean things up before script exit"""
+    if is_ddp():
+        dist.destroy_process_group()
+
+class DummyWandb:
+    """Useful if we wish to not use wandb but have all the same signatures"""
+    def __init__(self):
+        pass
+    def log(self, *args, **kwargs):
+        pass
+    def finish(self):
+        pass
diff --git a/nanochat/configurator.py b/nanochat/configurator.py
new file mode 100644
index 0000000..ec1b76d
--- /dev/null
+++ b/nanochat/configurator.py
@@ -0,0 +1,56 @@
+"""
+Poor Man's Configurator. Probably a terrible idea. Example usage:
+$ python train.py config/override_file.py --batch_size=32
+this will first run config/override_file.py, then override batch_size to 32
+
+The code in this file will be run as follows from e.g. train.py:
+>>> exec(open('configurator.py').read())
+
+So it's not a Python module, it's just shuttling this code away from train.py
+The code in this script then overrides the globals()
+
+I know people are not going to love this, I just really dislike configuration
+complexity and having to prepend config. to every single variable. If someone
+comes up with a better simple Python solution I am all ears.
+"""
+
+import os
+import sys
+from ast import literal_eval
+
+def print0(s="",**kwargs):
+    ddp_rank = int(os.environ.get('RANK', 0))
+    if ddp_rank == 0:
+        print(s, **kwargs)
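+
+# Illustrative walk-through of the loop below (example values, not from any script):
+#   `python base_train.py --depth=26` takes the else-branch: literal_eval("26") -> 26 (int),
+#   the type check passes against an int default, and globals()["depth"] becomes 26.
+#   A non-literal value like --run=myrun fails literal_eval and is kept as the string "myrun".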
+
+for arg in sys.argv[1:]:
+    if '=' not in arg:
+        # assume it's the name of a config file
+        assert not arg.startswith('--')
+        config_file = arg
+        print0(f"Overriding config with {config_file}:")
+        with open(config_file) as f:
+            print0(f.read())
+        exec(open(config_file).read())
+    else:
+        # assume it's a --key=value argument
+        assert arg.startswith('--')
+        key, val = arg.split('=')
+        key = key[2:]
+        if key in globals():
+            try:
+                # attempt to eval it (e.g. if bool, number, etc.)
+                attempt = literal_eval(val)
+            except (SyntaxError, ValueError):
+                # if that goes wrong, just use the string
+                attempt = val
+            # ensure the types match ok
+            if globals()[key] is not None:
+                attempt_type = type(attempt)
+                default_type = type(globals()[key])
+                assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
+            # cross fingers
+            print0(f"Overriding: {key} = {attempt}")
+            globals()[key] = attempt
+        else:
+            raise ValueError(f"Unknown config key: {key}")
diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py
new file mode 100644
index 0000000..f3c9a9f
--- /dev/null
+++ b/nanochat/core_eval.py
@@ -0,0 +1,262 @@
+"""
+Functions for evaluating the CORE metric, as described in the DCLM paper.
+https://arxiv.org/abs/2406.11794
+
+TODOs:
+- All tasks ~match except for squad: we get 31%, the reference is 37%. Figure out why.
+"""
+import random
+
+from jinja2 import Template
+import torch
+import torch.distributed as dist
+
+# -----------------------------------------------------------------------------
+# Prompt rendering utilities
+
+def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
+    """Render complete prompts for a multiple choice question"""
+    template_str = """
+{%- for example in fewshot_examples -%}
+{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
+
+{% endfor -%}
+{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
+    template = Template(template_str)
+    fewshot_examples = fewshot_examples or []
+    context = {
+        'fewshot_examples': fewshot_examples,
+        'continuation_delimiter': continuation_delimiter,
+        'item': item
+    }
+    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
+    return prompts
+
+
+def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
+    """Render complete prompts for a schema question"""
+    template_str = """
+{%- for example in fewshot_examples -%}
+{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
+
+{% endfor -%}
+{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
+    template = Template(template_str)
+    fewshot_examples = fewshot_examples or []
+    context = {
+        'fewshot_examples': fewshot_examples,
+        'continuation_delimiter': continuation_delimiter,
+        'item': item
+    }
+    prompts = [template.render(context=context_option, **context)
+               for context_option in item['context_options']]
+    return prompts
+
+
+def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
+    """
+    Render complete prompt for a language modeling task.
+    Notice that we manually trim the context in the template,
+    which in some datasets seems to have trailing whitespace (which we don't want).
+ """ + template_str = """ +{%- for example in fewshot_examples -%} +{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }} + +{% endfor -%} +{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip() + template = Template(template_str) + fewshot_examples = fewshot_examples or [] + context = { + 'fewshot_examples': fewshot_examples, + 'continuation_delimiter': continuation_delimiter, + 'item': item + } + # Return two prompts: without and with the continuation + prompt_without = template.render(include_continuation=False, **context) + prompt_with = template.render(include_continuation=True, **context) + # Due to the way the data seems to be stored, I think I need to strip in the case of LM here. + # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next + # token in prompt_with), meaning we don't get a nice and clean prefix in the token space + # to detect the final continuation. Tokenizers... + prompt_without = prompt_without.strip() + return [prompt_without, prompt_with] + + +def find_common_length(token_sequences, direction='left'): + """ + Find the length of the common prefix or suffix across token sequences + - direction: 'left' for prefix, 'right' for suffix + """ + min_len = min(len(seq) for seq in token_sequences) + indices = { + 'left': range(min_len), + 'right': range(-1, -min_len-1, -1) + }[direction] + # Find the first position where the token sequences differ + for i, idx in enumerate(indices): + token = token_sequences[0][idx] + if not all(seq[idx] == token for seq in token_sequences): + return i + return min_len + + +def stack_sequences(tokens, pad_token_id): + """Stack up a list of token sequences, pad to longest on the right""" + bsz, seq_len = len(tokens), max(len(x) for x in tokens) + input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) + for i, x in enumerate(tokens): + input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long) + return input_ids + + +def batch_sequences_mc(tokenizer, prompts): + # In multiple choice, contexts are the same but the continuation is different (common prefix) + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + # figure out the start and end of each continuation + answer_start_idx = find_common_length(tokens, direction='left') + start_indices = [answer_start_idx] * len(prompts) + end_indices = [len(x) for x in tokens] + return tokens, start_indices, end_indices + + +def batch_sequences_schema(tokenizer, prompts): + # In schema tasks, contexts vary but continuation is the same (common suffix) + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + # figure out the start and end of each context + suffix_length = find_common_length(tokens, direction='right') + end_indices = [len(x) for x in tokens] + start_indices = [ei - suffix_length for ei in end_indices] + return tokens, start_indices, end_indices + + +def batch_sequences_lm(tokenizer, prompts): + # In LM tasks, we have two prompts: without and with continuation + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + tokens_without, tokens_with = tokens + start_idx, end_idx = len(tokens_without), len(tokens_with) + assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with" + assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with" + # we only need the with continuation prompt in the LM task, i.e. 
batch size of 1 + return [tokens_with], [start_idx], [end_idx] + + +@torch.no_grad() +def forward_model(model, input_ids): + """ + Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. + The last column of losses is set to nan because we don't have autoregressive targets there. + """ + batch_size, seq_len = input_ids.size() + outputs = model(input_ids) + # Roll the tensor to the left by one position to get the (autoregressive) target ids + target_ids = torch.roll(input_ids, shifts=-1, dims=1) + # Calculate cross entropy at all positions + losses = torch.nn.functional.cross_entropy( + outputs.view(batch_size * seq_len, -1), + target_ids.view(batch_size * seq_len), + reduction='none' + ).view(batch_size, seq_len) + # Set the last column to be nan because there is no autoregressive loss there + losses[:, -1] = float('nan') + # Get the argmax predictions at each position + predictions = outputs.argmax(dim=-1) + return losses, predictions + + +@torch.no_grad() +def evaluate_example(idx, model, tokenizer, data, device, task_meta): + """Evaluate a single example, return True if correct, False otherwise""" + item = data[idx] + task_type = task_meta['task_type'] + num_fewshot = task_meta['num_fewshot'] + continuation_delimiter = task_meta['continuation_delimiter'] + + # Sample few-shot examples (excluding current item) + fewshot_examples = [] + if num_fewshot > 0: + rng = random.Random(1234 + idx) + available_indices = [i for i in range(len(data)) if i != idx] + fewshot_indices = rng.sample(available_indices, num_fewshot) + fewshot_examples = [data[i] for i in fewshot_indices] + + # Render prompts and batch sequences based on task type + if task_type == 'multiple_choice': + prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts) + elif task_type == 'schema': + prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts) + elif task_type == 'language_modeling': + prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts) + else: + raise ValueError(f"Unsupported task type: {task_type}") + + # Some models can't forward sequences beyond a certain length (e.g. GPT-2) + # In these cases, we have to truncate sequences to max length and adjust the indices + if hasattr(model, 'max_seq_len') and model.max_seq_len is not None: + max_tokens = model.max_seq_len + new_tokens, new_start_idxs, new_end_idxs = [], [], [] + for t, s, e in zip(tokens, start_idxs, end_idxs): + if len(t) > max_tokens: + num_to_crop = len(t) - max_tokens + new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens + new_start_idxs.append(s - num_to_crop) # shift the indices down + new_end_idxs.append(e - num_to_crop) + assert s - num_to_crop >= 0, "this should never happen right?" + assert e - num_to_crop >= 0, "this should never happen right?" 
+ else: + new_tokens.append(t) # keep unchanged + new_start_idxs.append(s) + new_end_idxs.append(e) + tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs + + # Stack up all the sequences into a batch + pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok + input_ids = stack_sequences(tokens, pad_token_id) + input_ids = input_ids.to(device) + + # Forward the model, get the autoregressive loss and argmax prediction at each token + losses, predictions = forward_model(model, input_ids) + + # See if the losses/predictions come out correctly + if task_type == 'language_modeling': + # language modeling task is currently always batch size 1 + si = start_idxs[0] + ei = end_idxs[0] + # predictions[i] predict input_ids[i+1] autoregressively + predicted_tokens = predictions[0, si-1:ei-1] + actual_tokens = input_ids[0, si:ei] + is_correct = torch.all(predicted_tokens == actual_tokens).item() + elif task_type in ['multiple_choice', 'schema']: + # For MC/schema: find the option with lowest average loss + mean_losses = [losses[i, si-1:ei-1].mean().item() + for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] + pred_idx = mean_losses.index(min(mean_losses)) + is_correct = pred_idx == item['gold'] + else: + raise ValueError(f"Unsupported task type: {task_type}") + + return is_correct + + +def evaluate_task(model, tokenizer, data, device, task_meta): + """ + This function is responsible for evaluating one task across many examples. + It also handles dispatch to all processes if the script is run with torchrun. + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + correct = torch.zeros(len(data), dtype=torch.float32, device=device) + # stride the examples to each rank + for idx in range(rank, len(data), world_size): + is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta) + correct[idx] = float(is_correct) + # sync results across all the processes if running distributed + if world_size > 1: + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + # compute the mean + mean_correct = correct.mean().item() + return mean_correct diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py new file mode 100644 index 0000000..c1636b1 --- /dev/null +++ b/nanochat/dataloader.py @@ -0,0 +1,49 @@ +from collections import deque + +import torch + +from nanochat.common import get_dist_info +from nanochat.dataset import parquets_iter_batched +from nanochat.tokenizer import get_tokenizer + +def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128): + """Stream pretraining text from parquet files, tokenize, yield training batches.""" + assert split in ["train", "val"], "split must be 'train' or 'val'" + ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + needed_tokens = B * T + 1 # +1 is because we also need the target at the last token + # get the tokenizer and the bos token + tokenizer = get_tokenizer() + bos_token = tokenizer.get_bos_token_id() + # scratch buffer holds the tokens for one iteration + token_buffer = deque() # we stream tokens on the right and pop from the left + scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) + + # infinite iterator over document batches + def document_batches(): + while True: + # batch will iterate in group size of the parquet files, usually e.g. 
1024 rows + for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size): + # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows + for i in range(0, len(batch), tokenizer_batch_size): + yield batch[i:i+tokenizer_batch_size] + batches = document_batches() + + batch_index = 0 + while True: + # Accumulate enough tokens for one iteration before yielding. + while len(token_buffer) < needed_tokens: + doc_batch = next(batches) + token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) + for tokens in token_lists: + token_buffer.extend(tokens) + batch_index += 1 + # Move tokens from the deque into the scratch buffer + for i in range(needed_tokens): + scratch[i] = token_buffer.popleft() + # Create the inputs/targets as 1D tensors + inputs_cpu = scratch[:-1].to(dtype=torch.int32) + targets_cpu = scratch[1:] + # Reshape to 2D and move to GPU async + inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True) + targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True) + yield inputs, targets diff --git a/nanochat/dataset.py b/nanochat/dataset.py new file mode 100644 index 0000000..602daed --- /dev/null +++ b/nanochat/dataset.py @@ -0,0 +1,128 @@ +""" +The base/pretraining dataset is a set of parquet files. +This file contains utilities for: +- iterating over the parquet files and yielding documents from it +- download the files on demand if they are not on disk + +For details of how the dataset was prepared, see `repackage_data_reference.py`. +""" + +import os +import argparse +import time +import requests +import pyarrow.parquet as pq +from multiprocessing import Pool + +from nanochat.common import get_base_dir + +# ----------------------------------------------------------------------------- +# The specifics of the current pretraining dataset + +# The URL on the internet where the data is hosted and downloaded from on demand +BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" +MAX_SHARD = 1822 # the last datashard is shard_01822.parquet +index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames +base_dir = get_base_dir() +DATA_DIR = os.path.join(base_dir, "base_data") +os.makedirs(DATA_DIR, exist_ok=True) + +# ----------------------------------------------------------------------------- +# These functions are useful utilities to other modules, can/should be imported + +def list_parquet_files(data_dir=None): + """ Looks into a data dir and returns full paths to all parquet files. """ + data_dir = DATA_DIR if data_dir is None else data_dir + parquet_files = sorted([ + f for f in os.listdir(data_dir) + if f.endswith('.parquet') and not f.endswith('.tmp') + ]) + parquet_paths = [os.path.join(data_dir, f) for f in parquet_files] + return parquet_paths + +def parquets_iter_batched(split, start=0, step=1): + """ + Iterate through the dataset, in batches of underlying row_groups for efficiency. + - split can be "train" or "val". the last parquet file will be val. + - start/step are useful for skipping rows in DDP. e.g. 
start=rank, step=world_size + """ + assert split in ["train", "val"], "split must be 'train' or 'val'" + parquet_paths = list_parquet_files() + parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] + for filepath in parquet_paths: + pf = pq.ParquetFile(filepath) + for rg_idx in range(start, pf.num_row_groups, step): + rg = pf.read_row_group(rg_idx) + texts = rg.column('text').to_pylist() + yield texts + +# ----------------------------------------------------------------------------- +def download_single_file(index): + """ Downloads a single file index, with some backoff """ + + # Construct the local filepath for this file and skip if it already exists + filename = index_to_filename(index) + filepath = os.path.join(DATA_DIR, filename) + if os.path.exists(filepath): + print(f"Skipping {filepath} (already exists)") + return True + + # Construct the remote URL for this file + url = f"{BASE_URL}/{filename}" + print(f"Downloading {filename}...") + + # Download with retries + max_attempts = 5 + for attempt in range(1, max_attempts + 1): + try: + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + # Write to temporary file first + temp_path = filepath + f".tmp" + with open(temp_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks + if chunk: + f.write(chunk) + # Move temp file to final location + os.rename(temp_path, filepath) + print(f"Successfully downloaded {filename}") + return True + + except (requests.RequestException, IOError) as e: + print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") + # Clean up any partial files + for path in [filepath + f".tmp", filepath]: + if os.path.exists(path): + try: + os.remove(path) + except: + pass + # Try a few times with exponential backoff: 2^attempt seconds + if attempt < max_attempts: + wait_time = 2 ** attempt + print(f"Waiting {wait_time} seconds before retry...") + time.sleep(wait_time) + else: + print(f"Failed to download {filename} after {max_attempts} attempts") + return False + + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards") + parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable") + parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") + args = parser.parse_args() + + num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) + ids_to_download = list(range(num)) + print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") + print(f"Target directory: {DATA_DIR}") + print() + with Pool(processes=args.num_workers) as pool: + results = pool.map(download_single_file, ids_to_download) + + # Report results + successful = sum(1 for success in results if success) + print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}") diff --git a/nanochat/engine.py b/nanochat/engine.py new file mode 100644 index 0000000..de1253a --- /dev/null +++ b/nanochat/engine.py @@ -0,0 +1,343 @@ +""" +Engine for efficient inference of our models. + +Everything works around token sequences: +- The user can send token sequences to the engine +- The engine returns the next token + +Notes: +- The engine knows nothing about tokenization, it's purely token id sequences. + +The whole thing is made as efficient as possible. 
+""" + +import torch +import torch.nn.functional as F +import signal +import warnings +from contextlib import contextmanager +from collections import deque +from nanochat.common import compute_init +from nanochat.checkpoint_manager import load_model + +# ----------------------------------------------------------------------------- +# Calculator tool helpers +@contextmanager +def timeout(duration, formula): + def timeout_handler(signum, frame): + raise Exception(f"'{formula}': timed out after {duration} seconds") + + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(duration) + yield + signal.alarm(0) + +def eval_with_timeout(formula, max_time=3): + try: + with timeout(max_time, formula): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", SyntaxWarning) + return eval(formula) + except Exception as e: + signal.alarm(0) + # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage + return None + +def use_calculator(expr): + """Evaluate a math expression safely.""" + expr = expr.replace(",", "") + if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars + return None + if "**" in expr: # for now disallow power operator, could be very expensive + return None + return eval_with_timeout(expr) + +# ----------------------------------------------------------------------------- +class KVCache: + """ + Works hand-in-hand with the GPT model to maintain the KV cache. + Note that the .pos advances automatically after the last layer of the Transformer inserts. + """ + + def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers): + # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer. + self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim) + self.kv_cache = None + self.pos = 0 # current position in time in the cache + + def reset(self): + self.pos = 0 + + def get_pos(self): + return self.pos + + def prefill(self, other): + """ + Prefill given another KV cache. Optionally expand along batch dim. + This is used when we do batch 1 prefill and then want to generate + multiple samples in parallel from there. 
+ """ + # 1) validate the shapes + assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" + assert other.kv_cache is not None, "Cannot prefill with a None KV cache" + for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): + if ix in [0, 1, 3, 5]: + # num_layers, batch_size, num_heads, head_dim must match + assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}" + elif ix == 2: + # batch_size can be expanded + assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}" + elif ix == 4: + # seq_len: self must be longer than other + assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}" + # 2) initialize the cache + dtype, device = other.kv_cache.dtype, other.kv_cache.device + self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) + # 3) copy the data over + self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache + # 4) update the pos + self.pos = other.pos + + def insert_kv(self, layer_idx, k, v): + # Lazy initialize the cache here because we need to know the dtype/device + if self.kv_cache is None: + self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device) + # Insert new keys/values to the cache and return the full cache so far + B, H, T_add, D = k.size() + t0, t1 = self.pos, self.pos + T_add + # Dynamically grow the cache if needed + if t1 > self.kv_cache.size(4): + t_needed = t1 + 1024 # as much as we need plus buffer of 1024 + t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 + current_shape = list(self.kv_cache.shape) + current_shape[4] = t_needed + self.kv_cache.resize_(current_shape) + # Insert k, v into the cache + self.kv_cache[layer_idx, 0, :, :, t0:t1] = k + self.kv_cache[layer_idx, 1, :, :, t0:t1] = v + # Return the full cached keys/values up to current position (as a view) + key_view = self.kv_cache[layer_idx, 0, :, :, :t1] + value_view = self.kv_cache[layer_idx, 1, :, :, :t1] + # Increment pos after the last layer of the Transformer processes + if layer_idx == self.kv_cache.size(0) - 1: + self.pos = t1 + return key_view, value_view + + +# ----------------------------------------------------------------------------- +@torch.inference_mode() +def sample_next_token(logits, rng, temperature=1.0, top_k=None): + """Sample a single next token from given logits of shape (B, vocab_size). 
Returns (B, 1).""" + assert temperature >= 0.0, "temperature must be non-negative" + if temperature == 0.0: + return torch.argmax(logits, dim=-1, keepdim=True) + if top_k is not None: + k = min(top_k, logits.size(-1)) + vals, idx = torch.topk(logits, k, dim=-1) + vals = vals / temperature + probs = F.softmax(vals, dim=-1) + choice = torch.multinomial(probs, num_samples=1, generator=rng) + return idx.gather(1, choice) + else: + logits = logits / temperature + probs = F.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1, generator=rng) + +# ----------------------------------------------------------------------------- + +class RowState: + # Per-row state tracking during generation + def __init__(self, current_tokens=None): + self.current_tokens = current_tokens or [] # Current token sequence for this row + self.forced_tokens = deque() # Queue of tokens to force inject + self.in_python_block = False # Whether we are inside a python block + self.python_expr_tokens = [] # Tokens of the current python expression + self.completed = False # Whether this row has completed generation + +class Engine: + + def __init__(self, model, tokenizer): + self.model = model + self.tokenizer = tokenizer # needed for tool use + + @torch.inference_mode() + def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): + """Same as generate, but does single prefill and then clones the KV cache.""" + assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints" + device = self.model.get_device() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + + # Get the special tokens we need to coordinate the tool use state machine + get_special = lambda s: self.tokenizer.encode_special(s) + python_start = get_special("<|python_start|>") + python_end = get_special("<|python_end|>") + output_start = get_special("<|output_start|>") + output_end = get_special("<|output_end|>") + assistant_end = get_special("<|assistant_end|>") # if sampled, ends row + bos = self.tokenizer.get_bos_token_id() # if sampled, ends row + + # 1) Run a batch 1 prefill of the prompt tokens + m = self.model.config + kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer} + kv_cache_prefill = KVCache( + batch_size=1, + seq_len=len(tokens), + **kv_model_kwargs, + ) + ids = torch.tensor([tokens], dtype=torch.long, device=device) + logits = self.model.forward(ids, kv_cache=kv_cache_prefill) + logits = logits[:, -1, :] + next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) + sampled_tokens = next_ids[:, 0].tolist() + + # 2) Replicate the KV cache for each sample/row + kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len + kv_cache_decode = KVCache( + batch_size=num_samples, + seq_len=kv_length_hint, + **kv_model_kwargs, + ) + kv_cache_decode.prefill(kv_cache_prefill) + del kv_cache_prefill # no need to keep this memory around + + # 3) Initialize states for each sample + row_states = [RowState(tokens.copy()) for _ in range(num_samples)] + + # 4) Main generation loop + num_generated = 0 + first_iteration = True + while True: + # Stop condition: we've reached max tokens + if max_tokens is not None and num_generated >= max_tokens: + break + # Stop condition: all rows are completed + if all(state.completed for state in row_states): + break + + # Get sampled tokens - either from prefill or from forward pass + if first_iteration: + # Use the tokens we 
already sampled from prefill + sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows + # TODO: we should sample a token for each row instead of broadcasting + first_iteration = False + else: + # Forward the model and get the next token for each row + logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size) + logits = logits[:, -1, :] # (B, vocab_size) at last time step + next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) + sampled_tokens = next_ids[:, 0].tolist() + + # Process each row: choose the next token, update state, optional tool use + token_column = [] # contains the next token id along each row + token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row + for i, state in enumerate(row_states): + # Select the next token in this row + is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque? + token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled + next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i] + token_column.append(next_token) + # Update the state of this row to include the next token + state.current_tokens.append(next_token) + # On <|assistant_end|> or <|bos|>, mark the row as completed + if next_token == assistant_end or next_token == bos: + state.completed = True + # Handle tool logic + if next_token == python_start: + state.in_python_block = True + state.python_expr_tokens = [] + elif next_token == python_end and state.in_python_block: + state.in_python_block = False + if state.python_expr_tokens: + expr = self.tokenizer.decode(state.python_expr_tokens) + result = use_calculator(expr) + if result is not None: + result_tokens = self.tokenizer.encode(str(result)) + state.forced_tokens.append(output_start) + state.forced_tokens.extend(result_tokens) + state.forced_tokens.append(output_end) + state.python_expr_tokens = [] + elif state.in_python_block: + state.python_expr_tokens.append(next_token) + + # Yield the token column + yield token_column, token_masks + num_generated += 1 + # Prepare ids for next iteration + ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) + + def generate_batch(self, tokens, num_samples=1, **kwargs): + """ + Non-streaming batch generation that just returns the final token sequences. + Returns a list of token sequences (list of lists of ints). + Terminal tokens (assistant_end, bos) are not included in the results. + """ + assistant_end = self.tokenizer.encode_special("<|assistant_end|>") + bos = self.tokenizer.get_bos_token_id() + results = [tokens.copy() for _ in range(num_samples)] + masks = [[0] * len(tokens) for _ in range(num_samples)] + completed = [False] * num_samples + for token_column, token_masks in self.generate(tokens, num_samples, **kwargs): + for i, (token, mask) in enumerate(zip(token_column, token_masks)): + if not completed[i]: + if token == assistant_end or token == bos: + completed[i] = True + else: + results[i].append(token) + masks[i].append(mask) + # Stop if all rows are completed + if all(completed): + break + return results, masks + + +if __name__ == "__main__": + """ + Quick inline test to make sure that the naive/slow model.generate function + is equivalent to the faster Engine.generate function here. 
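+ 
+ For reference, a minimal sketch of the non-streaming batch API defined above
+ (the prompt and sampling settings here are illustrative):
+ 
+ engine = Engine(model, tokenizer)
+ results, masks = engine.generate_batch(prompt_tokens, num_samples=4, max_tokens=64)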
+ """ + import time + # init compute + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() + # load the model and tokenizer + model, tokenizer, meta = load_model("base", device, phase="eval") + bos_token_id = tokenizer.get_bos_token_id() + # common hyperparameters + kwargs = dict(max_tokens=64, temperature=0.0) + # set the starting prompt + prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id) + # generate the reference sequence using the model.generate() function + generated_tokens = [] + torch.cuda.synchronize() + t0 = time.time() + stream = model.generate(prompt_tokens, **kwargs) + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) + print() + torch.cuda.synchronize() + t1 = time.time() + print(f"Reference time: {t1 - t0:.2f}s") + reference_ids = generated_tokens + # generate tokens with Engine + generated_tokens = [] + engine = Engine(model, tokenizer) + stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 + torch.cuda.synchronize() + t0 = time.time() + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) + print() + torch.cuda.synchronize() + t1 = time.time() + print(f"Engine time: {t1 - t0:.2f}s") + # compare the two sequences + for i in range(len(reference_ids)): + if reference_ids[i] != generated_tokens[i]: + print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}") + break + print(f"Match: {reference_ids == generated_tokens}") diff --git a/nanochat/execution.py b/nanochat/execution.py new file mode 100644 index 0000000..cda179d --- /dev/null +++ b/nanochat/execution.py @@ -0,0 +1,350 @@ +""" +Sandboxed execution utilities for running Python code that comes out of an LLM. +Adapted from OpenAI HumanEval code: +https://github.com/openai/human-eval/blob/master/human_eval/execution.py + +What is covered: +- Each execution runs in its own process (can be killed if it hangs or crashes) +- Execution is limited by a timeout to stop infinite loops +- Memory limits are enforced by default (256MB) +- stdout and stderr are captured and returned +- Code runs in a temporary directory that is deleted afterwards +- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen) + +What is not covered: +- Not a true security sandbox +- Network access is not blocked (e.g. sockets could be opened) +- Python's dynamic features (e.g. ctypes) could bypass restrictions +- No kernel-level isolation (no seccomp, no containers, no virtualization) + +Overall this sandbox is good for evaluation of generated code and protects against +accidental destructive behavior, but it is not safe against malicious adversarial code. 
+""" + +import contextlib +import faulthandler +import io +import multiprocessing +import os +import platform +import signal +import tempfile +from dataclasses import dataclass +from typing import Optional + +# ----------------------------------------------------------------------------- + +@dataclass +class ExecutionResult: + """Result of executing Python code in a sandbox.""" + success: bool + stdout: str + stderr: str + error: Optional[str] = None + timeout: bool = False + memory_exceeded: bool = False + + def __repr__(self): + parts = [] + parts.append(f"ExecutionResult(success={self.success}") + if self.timeout: + parts.append(", timeout=True") + if self.memory_exceeded: + parts.append(", memory_exceeded=True") + if self.error: + parts.append(f", error={self.error!r}") + if self.stdout: + parts.append(f", stdout={self.stdout!r}") + if self.stderr: + parts.append(f", stderr={self.stderr!r}") + parts.append(")") + return "".join(parts) + + +@contextlib.contextmanager +def time_limit(seconds: float): + def signal_handler(signum, frame): + raise TimeoutException("Timed out!") + + signal.setitimer(signal.ITIMER_REAL, seconds) + signal.signal(signal.SIGALRM, signal_handler) + try: + yield + finally: + signal.setitimer(signal.ITIMER_REAL, 0) + + +@contextlib.contextmanager +def capture_io(): + """Capture stdout and stderr, and disable stdin.""" + stdout_capture = io.StringIO() + stderr_capture = io.StringIO() + stdin_block = WriteOnlyStringIO() + with contextlib.redirect_stdout(stdout_capture): + with contextlib.redirect_stderr(stderr_capture): + with redirect_stdin(stdin_block): + yield stdout_capture, stderr_capture + + +@contextlib.contextmanager +def create_tempdir(): + with tempfile.TemporaryDirectory() as dirname: + with chdir(dirname): + yield dirname + + +class TimeoutException(Exception): + pass + + +class WriteOnlyStringIO(io.StringIO): + """StringIO that throws an exception when it's read from""" + + def read(self, *args, **kwargs): + raise IOError + + def readline(self, *args, **kwargs): + raise IOError + + def readlines(self, *args, **kwargs): + raise IOError + + def readable(self, *args, **kwargs): + """Returns True if the IO object can be read.""" + return False + + +class redirect_stdin(contextlib._RedirectStream): # type: ignore + _stream = "stdin" + + +@contextlib.contextmanager +def chdir(root): + if root == ".": + yield + return + cwd = os.getcwd() + os.chdir(root) + try: + yield + except BaseException as exc: + raise exc + finally: + os.chdir(cwd) + + +def reliability_guard(maximum_memory_bytes: Optional[int] = None): + """ + This disables various destructive functions and prevents the generated code + from interfering with the test (e.g. fork bomb, killing other processes, + removing filesystem files, etc.) + + WARNING + This function is NOT a security sandbox. Untrusted code, including, model- + generated code, should not be blindly executed outside of one. See the + Codex paper for more information about OpenAI's code sandbox, and proceed + with caution. 
+ """ + + if maximum_memory_bytes is not None: + import resource + + resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + if not platform.uname().system == "Darwin": + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + + faulthandler.disable() + + import builtins + + builtins.exit = None + builtins.quit = None + + import os + + os.environ["OMP_NUM_THREADS"] = "1" + + os.kill = None + os.system = None + os.putenv = None + os.remove = None + os.removedirs = None + os.rmdir = None + os.fchdir = None + os.setuid = None + os.fork = None + os.forkpty = None + os.killpg = None + os.rename = None + os.renames = None + os.truncate = None + os.replace = None + os.unlink = None + os.fchmod = None + os.fchown = None + os.chmod = None + os.chown = None + os.chroot = None + os.fchdir = None + os.lchflags = None + os.lchmod = None + os.lchown = None + os.getcwd = None + os.chdir = None + + import shutil + + shutil.rmtree = None + shutil.move = None + shutil.chown = None + + import subprocess + + subprocess.Popen = None # type: ignore + + __builtins__["help"] = None + + import sys + + sys.modules["ipdb"] = None + sys.modules["joblib"] = None + sys.modules["resource"] = None + sys.modules["psutil"] = None + sys.modules["tkinter"] = None + + +def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict): + """Execute code in a subprocess with safety guards. Results are written to result_dict.""" + with create_tempdir(): + + # These system calls are needed when cleaning up tempdir. + import os + import shutil + + rmtree = shutil.rmtree + rmdir = os.rmdir + chdir = os.chdir + + # Disable functionalities that can make destructive changes to the test. + reliability_guard(maximum_memory_bytes=maximum_memory_bytes) + + # Default to failure + result_dict.update({ + "success": False, + "stdout": "", + "stderr": "", + "timeout": False, + "memory_exceeded": False, + "error": None, + }) + + try: + exec_globals = {} + with capture_io() as (stdout_capture, stderr_capture): + with time_limit(timeout): + # WARNING + # This program exists to execute untrusted model-generated code. Although + # it is highly unlikely that model-generated code will do something overtly + # malicious in response to this test suite, model-generated code may act + # destructively due to a lack of model capability or alignment. + # Users are strongly encouraged to sandbox this evaluation suite so that it + # does not perform destructive actions on their host or network. For more + # information on how OpenAI sandboxes its code, see the accompanying paper. + # Once you have read this disclaimer and taken appropriate precautions, + # uncomment the following line and proceed at your own risk: + exec(code, exec_globals) + + result_dict.update({ + "success": True, + "stdout": stdout_capture.getvalue(), + "stderr": stderr_capture.getvalue(), + }) + + except TimeoutException: + result_dict.update({ + "timeout": True, + "error": "Execution timed out", + }) + + except MemoryError as e: + result_dict.update({ + "memory_exceeded": True, + "error": f"Memory limit exceeded: {e}", + }) + + except BaseException as e: + result_dict.update({ + "error": f"{type(e).__name__}: {e}", + }) + + # Needed for cleaning up. 
+ shutil.rmtree = rmtree
+ os.rmdir = rmdir
+ os.chdir = chdir
+ 
+ 
+ def execute_code(
+ code: str,
+ timeout: float = 5.0, # 5 seconds default
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
+ ) -> ExecutionResult:
+ """
+ Execute Python code in a sandboxed environment.
+ 
+ Args:
+ code: Python code to execute as a string
+ timeout: Maximum execution time in seconds (default: 5.0)
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
+ 
+ Returns:
+ ExecutionResult with success status, stdout/stderr, and error information
+ 
+ Example:
+ >>> result = execute_code("print('hello world')")
+ >>> result.success
+ True
+ >>> result.stdout
+ 'hello world\\n'
+ """
+ 
+ manager = multiprocessing.Manager()
+ result_dict = manager.dict()
+ 
+ p = multiprocessing.Process(
+ target=_unsafe_execute,
+ args=(code, timeout, maximum_memory_bytes, result_dict)
+ )
+ p.start()
+ p.join(timeout=timeout + 1)
+ 
+ if p.is_alive():
+ p.kill()
+ return ExecutionResult(
+ success=False,
+ stdout="",
+ stderr="",
+ error="Execution timed out (process killed)",
+ timeout=True,
+ memory_exceeded=False,
+ )
+ 
+ if not result_dict:
+ # the child died before writing any result (e.g. killed by the OS over
+ # the memory limit), so this is not necessarily a timeout
+ return ExecutionResult(
+ success=False,
+ stdout="",
+ stderr="",
+ error="Execution failed (no result returned)",
+ timeout=False,
+ memory_exceeded=False,
+ )
+ 
+ return ExecutionResult(
+ success=result_dict["success"],
+ stdout=result_dict["stdout"],
+ stderr=result_dict["stderr"],
+ error=result_dict["error"],
+ timeout=result_dict["timeout"],
+ memory_exceeded=result_dict["memory_exceeded"],
+ )
+ 
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
new file mode 100644
index 0000000..5a066b2
--- /dev/null
+++ b/nanochat/gpt.py
@@ -0,0 +1,322 @@
+ """
+ GPT model (rewrite, a lot simpler)
+ Notable features:
+ - rotary embeddings (and no positional embeddings)
+ - QK norm
+ - untied weights for token embedding and lm_head
+ - relu^2 activation in MLP
+ - norm after token embedding
+ - no learnable params in rmsnorm
+ - no bias in linear layers
+ - Multi-Query Attention (MQA) support for more efficient inference
+ """
+ 
+ import math
+ from functools import partial
+ from dataclasses import dataclass
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ from nanochat.common import get_dist_info, print0
+ from nanochat.muon import Muon, DistMuon
+ from nanochat.adamw import DistAdamW
+ 
+ @dataclass
+ class GPTConfig:
+ sequence_len: int = 1024
+ vocab_size: int = 50304
+ n_layer: int = 12
+ n_head: int = 6 # number of query heads
+ n_kv_head: int = 6 # number of key/value heads (MQA)
+ n_embd: int = 768
+ 
+ 
+ def norm(x):
+ # Purely functional rmsnorm with no learnable params
+ return F.rms_norm(x, (x.size(-1),))
+ 
+ 
+ def apply_rotary_emb(x, cos, sin):
+ assert x.ndim == 4 # multihead attention
+ d = x.shape[3] // 2
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
+ y2 = x1 * (-sin) + x2 * cos
+ out = torch.cat([y1, y2], 3) # re-assemble
+ out = out.to(x.dtype) # ensure input/output dtypes match
+ return out
+ 
+ 
+ def repeat_kv(x, n_rep):
+ """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
+ if n_rep == 1:
+ return x
+ bs, n_kv_heads, slen, head_dim = x.shape
+ return (
+ x[:, :, None, :, :]
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+ )
+ 
+ 
+ class CausalSelfAttention(nn.Module):
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.n_head = config.n_head
+ 
self.n_kv_head = config.n_kv_head + self.n_embd = config.n_embd + self.head_dim = self.n_embd // self.n_head + assert self.n_embd % self.n_head == 0 + assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 + self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) + self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + + def forward(self, x, cos_sin, kv_cache): + B, T, C = x.size() + + # Project the input to get queries, keys, and values + q = self.c_q(x).view(B, T, self.n_head, self.head_dim) + k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) + v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + + # Apply Rotary Embeddings to queries and keys to get relative positional encoding + cos, sin = cos_sin + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding + q, k = norm(q), norm(k) # QK norm + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) + + # Apply KV cache: insert current k,v into cache, get the full view so far + if kv_cache is not None: + k, v = kv_cache.insert_kv(self.layer_idx, k, v) + Tq = q.size(2) # number of queries in this forward pass + Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) + + # Apply MQA: replicate the key/value heads for each query head + nrep = self.n_head // self.n_kv_head + k, v = repeat_kv(k, nrep), repeat_kv(v, nrep) + + # Attention: queries attend to keys/values autoregressively. A few cases to handle: + if kv_cache is None or Tq == Tk: + # During training (no KV cache), attend as usual with causal attention + # And even if there is KV cache, we can still use this simple version when Tq == Tk + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + elif Tq == 1: + # During inference but with a single query in this forward pass: + # The query has to attend to all the keys/values in the cache + y = F.scaled_dot_product_attention(q, k, v, is_causal=False) + else: + # During inference AND we have a chunk of queries in this forward pass: + # First, each query attends to all the cached keys/values (i.e. 
full prefix) + attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask + prefix_len = Tk - Tq + if prefix_len > 0: # can't be negative but could be zero + attn_mask[:, :prefix_len] = True + # Then, causal attention within this chunk + attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + + # Re-assemble the heads side by side and project back to residual stream + y = y.transpose(1, 2).contiguous().view(B, T, -1) + y = self.c_proj(y) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + + def forward(self, x): + x = self.c_fc(x) + x = F.relu(x).square() + x = self.c_proj(x) + return x + + +class Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.attn = CausalSelfAttention(config, layer_idx) + self.mlp = MLP(config) + + def forward(self, x, cos_sin, kv_cache): + x = x + self.attn(norm(x), cos_sin, kv_cache) + x = x + self.mlp(norm(x)) + return x + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.transformer = nn.ModuleDict({ + "wte": nn.Embedding(config.vocab_size, config.n_embd), + "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), + }) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # To support meta device initialization, we init the rotary embeddings here, but it's fake + # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, + # so let's just over-compute them, but assert fail if we ever reach that amount. + # In the future we can dynamically grow the cache, for now it's fine. + self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? + head_dim = config.n_embd // config.n_head + cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) + self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint + self.register_buffer("sin", sin, persistent=False) + # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations + self.transformer.wte.to(dtype=torch.bfloat16) + + def init_weights(self): + self.apply(self._init_weights) + # zero out classifier weights + torch.nn.init.zeros_(self.lm_head.weight) + # zero out c_proj weights in all blocks + for block in self.transformer.h: + torch.nn.init.zeros_(block.mlp.c_proj.weight) + torch.nn.init.zeros_(block.attn.c_proj.weight) + # init the rotary embeddings + head_dim = self.config.n_embd // self.config.n_head + cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) + self.cos, self.sin = cos, sin + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + # https://arxiv.org/pdf/2310.17813 + fan_out = module.weight.size(0) + fan_in = module.weight.size(1) + std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=1.0) + + # TODO: bump base theta more, e.g. 
100K is more common recently
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+ # autodetect the device from model embeddings
+ if device is None:
+ device = self.transformer.wte.weight.device
+ # stride the channels
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
+ # stride the time steps
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
+ # calculate the rotation frequencies at each (time, channel) pair
+ freqs = torch.outer(t, inv_freq)
+ cos, sin = freqs.cos(), freqs.sin()
+ cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
+ return cos, sin
+ 
+ def get_device(self):
+ return self.transformer.wte.weight.device
+ 
+ def estimate_flops(self):
+ """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
+ nparams = sum(p.numel() for p in self.parameters())
+ nparams_embedding = self.transformer.wte.weight.numel()
+ l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
+ return num_flops_per_token
+ 
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
+ model_dim = self.config.n_embd
+ ddp, rank, local_rank, world_size = get_dist_info()
+ # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
+ matrix_params = list(self.transformer.h.parameters())
+ embedding_params = list(self.transformer.wte.parameters())
+ lm_head_params = list(self.lm_head.parameters())
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
+ # Create the AdamW optimizer for the embedding and lm_head
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
+ if rank == 0:
+ print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
+ adam_groups = [
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
+ ]
+ adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
+ # Create the Muon optimizer for the linear layers
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
+ MuonFactory = DistMuon if ddp else Muon
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
+ # Combine the two optimizers into one list
+ optimizers = [adamw_optimizer, muon_optimizer]
+ for opt in optimizers:
+ for group in opt.param_groups:
+ group["initial_lr"] = group["lr"]
+ return optimizers
+ 
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
+ B, T = idx.size()
+ 
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
+ assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in 
bfloat16" + # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache + T0 = 0 if kv_cache is None else kv_cache.get_pos() + cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length + + # Forward the trunk of the Transformer + x = self.transformer.wte(idx) + x = norm(x) + for block in self.transformer.h: + x = block(x, cos_sin, kv_cache) + x = norm(x) + + # Forward the lm_head (compute logits) + softcap = 15 + if targets is not None: + # training mode: compute and return the loss + # TODO: experiment with Liger Kernels / chunked cross-entropy etc. + logits = self.lm_head(x) + logits = softcap * torch.tanh(logits / softcap) # logits softcap + logits = logits.float() # use tf32/fp32 for logits + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) + return loss + else: + # inference mode: compute and return the logits + logits = self.lm_head(x) + logits = softcap * torch.tanh(logits / softcap) # logits softcap + return logits + + @torch.inference_mode() + def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): + """ + Naive autoregressive streaming inference. + To make it super simple, let's assume: + - batch size is 1 + - ids and the yielded tokens are simple Python lists and ints + """ + assert isinstance(tokens, list) + device = self.get_device() + rng = None + if temperature > 0: + rng = torch.Generator(device=device) + rng.manual_seed(seed) + ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim + for _ in range(max_tokens): + logits = self.forward(ids) # (B, T, vocab_size) + logits = logits[:, -1, :] # (B, vocab_size) + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + if temperature > 0: + logits = logits / temperature + probs = F.softmax(logits, dim=-1) + next_ids = torch.multinomial(probs, num_samples=1, generator=rng) + else: + next_ids = torch.argmax(logits, dim=-1, keepdim=True) + ids = torch.cat((ids, next_ids), dim=1) + token = next_ids.item() + yield token diff --git a/nanochat/logo.svg b/nanochat/logo.svg new file mode 100644 index 0000000..f0f0571 --- /dev/null +++ b/nanochat/logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py new file mode 100644 index 0000000..d103ef6 --- /dev/null +++ b/nanochat/loss_eval.py @@ -0,0 +1,63 @@ +""" +A number of functions that help with evaluating a base model. +""" +import math +import torch +import torch.distributed as dist + +@torch.no_grad() +def evaluate_bpb(model, batches, steps, token_bytes): + """ + Instead of the naive 'mean loss', this function returns the bits per byte (bpb), + which is a tokenization vocab size-indepedent metric, meaning you are still comparing + apples:apples if you change the vocab size. The way this works is that instead of just + calculating the average loss as usual, you calculate the sum loss, and indepependently + also the sum bytes (of all the target tokens), and divide. This normalizes the loss by + the number of bytes that the target tokens represent. + + The added complexity is so that: + 1) All "normal" tokens are normalized by the length of the token in bytes + 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. + 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. 
+ + In addition to evaluate_loss, we need the token_bytes tensor: + It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for + each token id, or 0 if the token is to not be counted (e.g. special tokens). + """ + # record the losses + total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) + total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) + batch_iter = iter(batches) + for _ in range(steps): + x, y = next(batch_iter) + loss2d = model(x, y, loss_reduction='none') # (B, T) + loss2d = loss2d.view(-1) # flatten + y = y.view(-1) # flatten + if (y < 0).any(): + # slightly more complex code path if some target tokens are ignore_index (e.g. -1) + # any target token < 0 is to be ignored: do NOT index token_bytes with negatives + valid = y >= 0 + y_safe = torch.where(valid, y, torch.zeros_like(y)) + # map valid targets to their byte length; ignored targets contribute 0 bytes + num_bytes2d = torch.where( + valid, + token_bytes[y_safe], + torch.zeros_like(y, dtype=token_bytes.dtype) + ) + total_nats += (loss2d * (num_bytes2d > 0)).sum() + total_bytes += num_bytes2d.sum() + else: + # fast path: no ignored targets, safe to index directly + num_bytes2d = token_bytes[y] + total_nats += (loss2d * (num_bytes2d > 0)).sum() + total_bytes += num_bytes2d.sum() + # sum reduce across all ranks + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size > 1: + dist.all_reduce(total_nats, op=dist.ReduceOp.SUM) + dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM) + # move both to cpu, calculate bpb and return + total_nats = total_nats.item() + total_bytes = total_bytes.item() + bpb = total_nats / (math.log(2) * total_bytes) + return bpb diff --git a/nanochat/muon.py b/nanochat/muon.py new file mode 100644 index 0000000..d916103 --- /dev/null +++ b/nanochat/muon.py @@ -0,0 +1,187 @@ +""" +Muon optimizer from Keller et al. +Also a lot of borrowing of ideas from modded-nanogpt. +""" +import torch +from torch import Tensor +import torch.distributed as dist + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
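+ 
+ A quick sanity check (illustrative; inside Muon, G is a 2D momentum-averaged
+ gradient and the iteration runs in bfloat16):
+ 
+ G = torch.randn(256, 128)
+ X = zeropower_via_newtonschulz5(G, steps=5)
+ # X has G's shape and near-orthonormal columns, up to the S' factor above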
+ """ + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + + Arguments: + lr: The learning rate used by the internal SGD. + momentum: The momentum used by the internal SGD. + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iteration steps to use. + """ + def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + params: list[Tensor] = [*params] + param_groups = [] + for size in {p.numel() for p in params}: + group = dict(params=[p for p in params if p.numel() == size]) + param_groups.append(group) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + params: list[Tensor] = group["params"] + for p in params: + g = p.grad + assert g is not None + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf: Tensor = state["momentum_buffer"] + buf.lerp_(g, 1 - group["momentum"]) + g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5) + + +class DistMuon(torch.optim.Optimizer): + """ + Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz, + finally apply aspect-ratio scaled step. Performs its own distributed synchronization: + - reduce_scatter(AVG) for gradient averaging + - all_gather to replicate updated weights + + Notes: + * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D + params like embeddings or scalars. + * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen + by block-cyclic assignment below). If you checkpoint optimizer state on a single rank, + consolidate states beforehand. 
+ + Args: + params: iterable of Tensors + lr: learning rate + momentum: momentum coefficient in [0,1) + nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf + ns_steps: number of Newton–Schulz iterations for the orthogonalization + """ + def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, + nesterov: bool = True, ns_steps: int = 5): + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + params = list(params) + assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" + rank = dist.get_rank() + # Group all parameters by their shape + shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering + param_groups = [] + for shape in shapes: + group_params = [p for p in params if p.shape == shape] + device, dtype = group_params[0].device, group_params[0].dtype + assert all(p.device == device for p in group_params) + assert all(p.dtype == dtype for p in group_params) + if rank == 0: + print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}") + param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Ensure all grads exist + assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" + + # Kick off all the reduce scatter operations to average up the gradients across all ranks + all_reduce_futures = [] + for group in self.param_groups: + params = group["params"] + zero_buffer = group["zero_buffer"] + # Go through params in groups of world_size. + for base_i in range(0, len(params), world_size): + # The compute owner of each param is rank i % world_size + owner_idx = base_i + rank + # each rank stacks up its chunk of world_size params into a list + rs_input = [p.grad for p in params[base_i:base_i + world_size]] + # pad rs_input with the zero buffer to complete the group + rs_input.extend([zero_buffer] * (world_size - len(rs_input))) + # the output buffer gets strided across the group based on the rank + rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer) + # reduce scatter the gradients within this group of world_size params + work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future() + all_reduce_futures.append(work) + + # Now each rank computes the update and gathers + future_idx = 0 + all_gather_futures = [] + for group in self.param_groups: + params = group["params"] + zero_buffer = group["zero_buffer"] + # Go through params in groups of world_size. 
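+ # (block-cyclic ownership: with world_size W, rank r owns params r, r+W, r+2W, ...)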
+ for base_i in range(0, len(params), world_size): + # The compute owner of each param is rank i % world_size + owner_idx = base_i + rank # calculate the index of the param that this rank owns + # Wait for the reduce scatter to complete + all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead + future_idx += 1 + # Owner computes the Muon update, result is in its param + if owner_idx < len(params): + p = params[owner_idx] + g = p.grad # now averaged across ranks + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf: Tensor = state["momentum_buffer"] + buf.lerp_(g, 1.0 - group["momentum"]) + g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) + p.add_(g, alpha=-group["lr"] * scale) + # Replicate updated parameters to all ranks + ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer + ag_output = params[base_i:base_i + world_size] + ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad + work = dist.all_gather(ag_output, ag_input, async_op=True).get_future() + all_gather_futures.append(work) + + # Wait for all work to finish + torch.futures.collect_all(all_gather_futures).wait() diff --git a/nanochat/report.py b/nanochat/report.py new file mode 100644 index 0000000..02cd8b0 --- /dev/null +++ b/nanochat/report.py @@ -0,0 +1,404 @@ +""" +Utilities for generating training report cards. More messy code than usual, will fix. +""" + +import os +import re +import shutil +import subprocess +import socket +import datetime +import platform +import psutil +import torch + +def run_command(cmd): + """Run a shell command and return output, or None if it fails.""" + try: + result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5) + if result.returncode == 0: + return result.stdout.strip() + return None + except: + return None + +def get_git_info(): + """Get current git commit, branch, and dirty status.""" + info = {} + info['commit'] = run_command("git rev-parse --short HEAD") or "unknown" + info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown" + + # Check if repo is dirty (has uncommitted changes) + status = run_command("git status --porcelain") + info['dirty'] = bool(status) if status is not None else False + + # Get commit message + info['message'] = run_command("git log -1 --pretty=%B") or "" + info['message'] = info['message'].split('\n')[0][:80] # First line, truncated + + return info + +def get_gpu_info(): + """Get GPU information.""" + if not torch.cuda.is_available(): + return {"available": False} + + num_devices = torch.cuda.device_count() + info = { + "available": True, + "count": num_devices, + "names": [], + "memory_gb": [] + } + + for i in range(num_devices): + props = torch.cuda.get_device_properties(i) + info["names"].append(props.name) + info["memory_gb"].append(props.total_memory / (1024**3)) + + # Get CUDA version + info["cuda_version"] = torch.version.cuda or "unknown" + + return info + +def get_system_info(): + """Get system information.""" + info = {} + + # Basic system info + info['hostname'] = socket.gethostname() + info['platform'] = platform.system() + info['python_version'] = platform.python_version() + info['torch_version'] = torch.__version__ + + # CPU and memory + info['cpu_count'] = psutil.cpu_count(logical=False) + 
info['cpu_count_logical'] = psutil.cpu_count(logical=True) + info['memory_gb'] = psutil.virtual_memory().total / (1024**3) + + # User and environment + info['user'] = os.environ.get('USER', 'unknown') + info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out') + info['working_dir'] = os.getcwd() + + return info + +def estimate_cost(gpu_info, runtime_hours=None): + """Estimate training cost based on GPU type and runtime.""" + + # Rough pricing, from Lambda Cloud + default_rate = 2.0 + gpu_hourly_rates = { + "H100": 3.00, + "A100": 1.79, + "V100": 0.55, + } + + if not gpu_info.get("available"): + return None + + # Try to identify GPU type from name + hourly_rate = None + gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown" + for gpu_type, rate in gpu_hourly_rates.items(): + if gpu_type in gpu_name: + hourly_rate = rate * gpu_info["count"] + break + + if hourly_rate is None: + hourly_rate = default_rate * gpu_info["count"] # Default estimate + + return { + "hourly_rate": hourly_rate, + "gpu_type": gpu_name, + "estimated_total": hourly_rate * runtime_hours if runtime_hours else None + } + +def generate_header(): + """Generate the header for a training report.""" + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + git_info = get_git_info() + gpu_info = get_gpu_info() + sys_info = get_system_info() + cost_info = estimate_cost(gpu_info) + + header = f"""# nanochat training report + +Generated: {timestamp} + +## Environment + +### Git Information +- Branch: {git_info['branch']} +- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"} +- Message: {git_info['message']} + +### Hardware +- Platform: {sys_info['platform']} +- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical) +- Memory: {sys_info['memory_gb']:.1f} GB +""" + + if gpu_info.get("available"): + gpu_names = ", ".join(set(gpu_info["names"])) + total_vram = sum(gpu_info["memory_gb"]) + header += f"""- GPUs: {gpu_info['count']}x {gpu_names} +- GPU Memory: {total_vram:.1f} GB total +- CUDA Version: {gpu_info['cuda_version']} +""" + else: + header += "- GPUs: None available\n" + + if cost_info and cost_info["hourly_rate"] > 0: + header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n""" + + header += f""" +### Software +- Python: {sys_info['python_version']} +- PyTorch: {sys_info['torch_version']} + +""" + + # bloat metrics: package all of the source code and assess its weight + packaged = run_command('files-to-prompt . 
-e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
+ num_chars = len(packaged)
+ num_lines = len(packaged.split('\n'))
+ # each file appears as a <source>...</source> line in the cxml output
+ num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
+ 
+ # count dependencies via uv.lock
+ uv_lock_lines = 0
+ if os.path.exists('uv.lock'):
+ with open('uv.lock', 'r') as f:
+ uv_lock_lines = len(f.readlines())
+ 
+ header += f"""
+ ### Bloat
+ - Characters: {num_chars:,}
+ - Lines: {num_lines:,}
+ - Files: {num_files:,}
+ - Tokens (approx): {num_tokens:,}
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
+ 
+ """
+ return header
+ 
+ # -----------------------------------------------------------------------------
+ 
+ def slugify(text):
+ """Slugify a text string."""
+ return text.lower().replace(" ", "-")
+ 
+ # the expected files and their order
+ EXPECTED_FILES = [
+ "tokenizer-training.md",
+ "tokenizer-evaluation.md",
+ "base-model-training.md",
+ "base-model-loss.md",
+ "base-model-evaluation.md",
+ "midtraining.md",
+ "chat-evaluation-mid.md",
+ "chat-sft.md",
+ "chat-evaluation-sft.md",
+ "chat-rl.md",
+ "chat-evaluation-rl.md",
+ ]
+ # the metrics we're currently interested in
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
+ 
+ def extract(section, keys):
+ """simple helper to extract the given keys from a section"""
+ if not isinstance(keys, list):
+ keys = [keys] # convenience
+ out = {}
+ for line in section.split("\n"):
+ for key in keys:
+ if key in line:
+ out[key] = line.split(":")[1].strip()
+ return out
+ 
+ def extract_timestamp(content, prefix):
+ """Extract timestamp from content with given prefix."""
+ for line in content.split('\n'):
+ if line.startswith(prefix):
+ time_str = line.split(":", 1)[1].strip()
+ try:
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
+ except:
+ pass
+ return None
+ 
+ class Report:
+ """Maintains a bunch of logs, generates a final markdown report."""
+ 
+ def __init__(self, report_dir):
+ os.makedirs(report_dir, exist_ok=True)
+ self.report_dir = report_dir
+ 
+ def log(self, section, data):
+ """Log a section of data to the report."""
+ slug = slugify(section)
+ file_name = f"{slug}.md"
+ file_path = os.path.join(self.report_dir, file_name)
+ with open(file_path, "w") as f:
+ f.write(f"## {section}\n")
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
+ for item in data:
+ if not item:
+ # skip falsy values like None or empty dict etc. 
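+ # (each item is either a raw string or a {key: value} dict, handled below)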
+ continue + if isinstance(item, str): + # directly write the string + f.write(item) + else: + # render a dict + for k, v in item.items(): + if isinstance(v, float): + vstr = f"{v:.4f}" + elif isinstance(v, int) and v >= 10000: + vstr = f"{v:,.0f}" + else: + vstr = str(v) + f.write(f"- {k}: {vstr}\n") + f.write("\n") + return file_path + + def generate(self): + """Generate the final report.""" + report_dir = self.report_dir + report_file = os.path.join(report_dir, "report.md") + print(f"Generating report to {report_file}") + final_metrics = {} # the most important final metrics we'll add as table at the end + start_time = None + end_time = None + with open(report_file, "w") as out_file: + # write the header first + header_file = os.path.join(report_dir, "header.md") + if os.path.exists(header_file): + with open(header_file, "r") as f: + header_content = f.read() + out_file.write(header_content) + start_time = extract_timestamp(header_content, "Run started:") + # capture bloat data for summary later (the stuff after Bloat header and until \n\n) + bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL) + bloat_data = bloat_data.group(1) if bloat_data else "" + # process all the individual sections + for file_name in EXPECTED_FILES: + section_file = os.path.join(report_dir, file_name) + if not os.path.exists(section_file): + print(f"Warning: {section_file} does not exist, skipping") + continue + with open(section_file, "r") as in_file: + section = in_file.read() + # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) + if "rl" not in file_name: + # Skip RL sections for end_time calculation because RL is experimental + end_time = extract_timestamp(section, "timestamp:") + # extract the most important metrics from the sections + if file_name == "base-model-evaluation.md": + final_metrics["base"] = extract(section, "CORE") + if file_name == "chat-evaluation-mid.md": + final_metrics["mid"] = extract(section, chat_metrics) + if file_name == "chat-evaluation-sft.md": + final_metrics["sft"] = extract(section, chat_metrics) + if file_name == "chat-evaluation-rl.md": + final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K + # append this section of the report + out_file.write(section) + out_file.write("\n") + # add the final metrics table + out_file.write("## Summary\n\n") + # Copy over the bloat metrics from the header + out_file.write(bloat_data) + out_file.write("\n\n") + # Collect all unique metric names + all_metrics = set() + for stage_metrics in final_metrics.values(): + all_metrics.update(stage_metrics.keys()) + # Custom ordering: CORE first, ChatCORE last, rest in middle + all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x)) + # Fixed column widths + stages = ["base", "mid", "sft", "rl"] + metric_width = 15 + value_width = 8 + # Write table header + header = f"| {'Metric'.ljust(metric_width)} |" + for stage in stages: + header += f" {stage.upper().ljust(value_width)} |" + out_file.write(header + "\n") + # Write separator + separator = f"|{'-' * (metric_width + 2)}|" + for stage in stages: + separator += f"{'-' * (value_width + 2)}|" + out_file.write(separator + "\n") + # Write table rows + for metric in all_metrics: + row = f"| {metric.ljust(metric_width)} |" + for stage in stages: + value = final_metrics.get(stage, {}).get(metric, "-") + row += f" {str(value).ljust(value_width)} |" + out_file.write(row + "\n") + out_file.write("\n") + # Calculate and write total wall clock time + if 
start_time and end_time: + duration = end_time - start_time + total_seconds = int(duration.total_seconds()) + hours = total_seconds // 3600 + minutes = (total_seconds % 3600) // 60 + out_file.write(f"Total wall clock time: {hours}h{minutes}m\n") + else: + out_file.write("Total wall clock time: unknown\n") + # also cp the report.md file to current directory + print(f"Copying report.md to current directory for convenience") + shutil.copy(report_file, "report.md") + return report_file + + def reset(self): + """Reset the report.""" + # Remove section files + for file_name in EXPECTED_FILES: + file_path = os.path.join(self.report_dir, file_name) + if os.path.exists(file_path): + os.remove(file_path) + # Remove report.md if it exists + report_file = os.path.join(self.report_dir, "report.md") + if os.path.exists(report_file): + os.remove(report_file) + # Generate and write the header section with start timestamp + header_file = os.path.join(self.report_dir, "header.md") + header = generate_header() + start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + with open(header_file, "w") as f: + f.write(header) + f.write(f"Run started: {start_time}\n\n---\n\n") + print(f"Reset report and wrote header to {header_file}") + +# ----------------------------------------------------------------------------- +# nanochat-specific convenience functions + +class DummyReport: + def log(self, *args, **kwargs): + pass + def reset(self, *args, **kwargs): + pass + +def get_report(): + # just for convenience, only rank 0 logs to report + from nanochat.common import get_base_dir, get_dist_info + ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + if ddp_rank == 0: + report_dir = os.path.join(get_base_dir(), "report") + return Report(report_dir) + else: + return DummyReport() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.") + parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)") + args = parser.parse_args() + if args.command == "generate": + get_report().generate() + elif args.command == "reset": + get_report().reset() diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py new file mode 100644 index 0000000..68cd436 --- /dev/null +++ b/nanochat/tokenizer.py @@ -0,0 +1,395 @@ +""" +BPE Tokenizer in the style of GPT-4. + +Two implementations are available: +1) HuggingFace Tokenizer that can do both training and inference but is really confusing +2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference +""" + +import os +import copy +from functools import lru_cache + +SPECIAL_TOKENS = [ + # every document begins with the Beginning of Sequence (BOS) token that delimits documents + "<|bos|>", + # tokens below are only used during finetuning to render Conversations into token ids + "<|user_start|>", # user messages + "<|user_end|>", + "<|assistant_start|>", # assistant messages + "<|assistant_end|>", + "<|python_start|>", # assistant invokes python REPL tool + "<|python_end|>", + "<|output_start|>", # python REPL outputs back to assistant + "<|output_end|>", +] + +# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3} +# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. +# I haven't validated that this is actually a good idea, TODO. 
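+ # For example, under \p{N}{1,2} the string "12345" pre-tokenizes into "12", "34", "5",
+ # whereas GPT-4's \p{N}{1,3} would give "123", "45".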
+SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" + +# ----------------------------------------------------------------------------- +# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer +from tokenizers import Tokenizer as HFTokenizer +from tokenizers import pre_tokenizers, decoders, Regex +from tokenizers.models import BPE +from tokenizers.trainers import BpeTrainer + +class HuggingFaceTokenizer: + """Light wrapper around HuggingFace Tokenizer for some utilities""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + @classmethod + def from_pretrained(cls, hf_path): + # init from a HuggingFace pretrained tokenizer (e.g. "gpt2") + tokenizer = HFTokenizer.from_pretrained(hf_path) + return cls(tokenizer) + + @classmethod + def from_directory(cls, tokenizer_dir): + # init from a local directory on disk (e.g. "out/tokenizer") + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + tokenizer = HFTokenizer.from_file(tokenizer_path) + return cls(tokenizer) + + @classmethod + def train_from_iterator(cls, text_iterator, vocab_size): + # train from an iterator of text + # Configure the HuggingFace Tokenizer + tokenizer = HFTokenizer(BPE( + byte_fallback=True, # needed! + unk_token=None, + fuse_unk=False, + )) + # Normalizer: None + tokenizer.normalizer = None + # Pre-tokenizer: GPT-4 style + # the regex pattern used by GPT-4 to split text into groups before BPE + # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to + # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space. + # (but I haven't validated this! TODO) + gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! + tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ + pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) + ]) + # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) + tokenizer.decoder = decoders.ByteLevel() + # Post-processor: None + tokenizer.post_processor = None + # Trainer: BPE + trainer = BpeTrainer( + vocab_size=vocab_size, + show_progress=True, + min_frequency=0, # no minimum frequency + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + special_tokens=SPECIAL_TOKENS, + ) + # Kick off the training + tokenizer.train_from_iterator(text_iterator, trainer) + return cls(tokenizer) + + def get_vocab_size(self): + return self.tokenizer.get_vocab_size() + + def get_special_tokens(self): + special_tokens_map = self.tokenizer.get_added_tokens_decoder() + special_tokens = [w.content for w in special_tokens_map.values()] + return special_tokens + + def id_to_token(self, id): + return self.tokenizer.id_to_token(id) + + def _encode_one(self, text, prepend=None, append=None): + # encode a single string + # prepend/append can be either a string of a special token or a token id directly. 
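+ # e.g. _encode_one("hello", prepend="<|bos|>"); an integer token id works too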
+ assert isinstance(text, str) + ids = [] + if prepend is not None: + prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend) + ids.append(prepend_id) + ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids) + if append is not None: + append_id = append if isinstance(append, int) else self.encode_special(append) + ids.append(append_id) + return ids + + def encode_special(self, text): + # encode a single special token via exact match + return self.tokenizer.token_to_id(text) + + def get_bos_token_id(self): + bos = self.encode_special("<|bos|>") + return bos + + def encode(self, text, *args, **kwargs): + if isinstance(text, str): + return self._encode_one(text, *args, **kwargs) + elif isinstance(text, list): + return [self._encode_one(t, *args, **kwargs) for t in text] + else: + raise ValueError(f"Invalid input type: {type(text)}") + + def __call__(self, *args, **kwargs): + return self.encode(*args, **kwargs) + + def decode(self, ids): + return self.tokenizer.decode(ids, skip_special_tokens=False) + + def save(self, tokenizer_dir): + # save the tokenizer to disk + os.makedirs(tokenizer_dir, exist_ok=True) + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + self.tokenizer.save(tokenizer_path) + print(f"Saved tokenizer to {tokenizer_path}") + +# ----------------------------------------------------------------------------- +# Tokenizer based on rustbpe + tiktoken combo +import pickle +import rustbpe +import tiktoken + +class RustBPETokenizer: + """Light wrapper around tiktoken (for efficient inference) but train with rustbpe""" + + def __init__(self, enc, bos_token): + self.enc = enc + self.bos_token_id = self.encode_special(bos_token) + + @classmethod + def train_from_iterator(cls, text_iterator, vocab_size): + # 1) train using rustbpe + tokenizer = rustbpe.Tokenizer() + # the special tokens are inserted later in __init__, we don't train them here + vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS) + assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}" + tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN) + # 2) construct the associated tiktoken encoding for inference + pattern = tokenizer.get_pattern() + mergeable_ranks_list = tokenizer.get_mergeable_ranks() + mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list} + tokens_offset = len(mergeable_ranks) + special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)} + enc = tiktoken.Encoding( + name="rustbpe", + pat_str=pattern, + mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank) + special_tokens=special_tokens, # dict[str, int] (special token name -> token id) + ) + return cls(enc, "<|bos|>") + + @classmethod + def from_directory(cls, tokenizer_dir): + pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl") + with open(pickle_path, "rb") as f: + enc = pickle.load(f) + return cls(enc, "<|bos|>") + + @classmethod + def from_pretrained(cls, tiktoken_name): + # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py + enc = tiktoken.get_encoding(tiktoken_name) + # tiktoken calls the special document delimiter token "<|endoftext|>" + # yes this is confusing because this token is almost always PREPENDED to the beginning of the document + # it most often is used to signal the start of a new sequence to the LLM during inference etc. 
+ # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>". + return cls(enc, "<|endoftext|>") + + def get_vocab_size(self): + return self.enc.n_vocab + + def get_special_tokens(self): + return self.enc.special_tokens_set + + def id_to_token(self, id): + return self.enc.decode([id]) + + @lru_cache(maxsize=32) + def encode_special(self, text): + return self.enc.encode_single_token(text) + + def get_bos_token_id(self): + return self.bos_token_id + + def encode(self, text, prepend=None, append=None, num_threads=8): + # text can be either a string or a list of strings + + if prepend is not None: + prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend) + if append is not None: + append_id = append if isinstance(append, int) else self.encode_special(append) + + if isinstance(text, str): + ids = self.enc.encode_ordinary(text) + if prepend is not None: + ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm + if append is not None: + ids.append(append_id) + elif isinstance(text, list): + ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) + if prepend is not None: + for ids_row in ids: + ids_row.insert(0, prepend_id) # TODO: same + if append is not None: + for ids_row in ids: + ids_row.append(append_id) + else: + raise ValueError(f"Invalid input type: {type(text)}") + + return ids + + def __call__(self, *args, **kwargs): + return self.encode(*args, **kwargs) + + def decode(self, ids): + return self.enc.decode(ids) + + def save(self, tokenizer_dir): + # save the encoding object to disk + os.makedirs(tokenizer_dir, exist_ok=True) + pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl") + with open(pickle_path, "wb") as f: + pickle.dump(self.enc, f) + print(f"Saved tokenizer encoding to {pickle_path}") + + def render_conversation(self, conversation, max_tokens=2048): + """ + Tokenize a single Chat conversation (which we call a "doc" or "document" here). + Returns: + - ids: list[int] is a list of token ids of this rendered conversation + - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on. + """ + # ids, masks that we will return and a helper function to help build them up. + ids, mask = [], [] + def add_tokens(token_ids, mask_val): + if isinstance(token_ids, int): + token_ids = [token_ids] + ids.extend(token_ids) + mask.extend([mask_val] * len(token_ids)) + + # sometimes the first message is a system message... + # => just merge it with the second (user) message + if conversation["messages"][0]["role"] == "system": + # some conversation surgery is necessary here for now... 
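+            # e.g. [{system: "Be brief."}, {user: "Hi"}] effectively becomes
+            # [{user: "Be brief.\n\nHi"}] via the join below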
+ conversation = copy.deepcopy(conversation) # avoid mutating the original + messages = conversation["messages"] + assert messages[1]["role"] == "user", "System message must be followed by a user message" + messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"] + messages = messages[1:] + else: + messages = conversation["messages"] + assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}" + + # fetch all the special tokens we need + bos = self.get_bos_token_id() + user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>") + assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>") + python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>") + output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>") + + # now we can tokenize the conversation + add_tokens(bos, 0) + for i, message in enumerate(messages): + + # some sanity checking here around assumptions, to prevent footguns + must_be_from = "user" if i % 2 == 0 else "assistant" + assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}" + + # content can be either a simple string or a list of parts (e.g. containing tool calls) + content = message["content"] + + if message["role"] == "user": + assert isinstance(content, str), "User messages are simply expected to be strings" + value_ids = self.encode(content) + add_tokens(user_start, 0) + add_tokens(value_ids, 0) + add_tokens(user_end, 0) + elif message["role"] == "assistant": + add_tokens(assistant_start, 0) + if isinstance(content, str): + # simple string => simply add the tokens + value_ids = self.encode(content) + add_tokens(value_ids, 1) + elif isinstance(content, list): + for part in content: + value_ids = self.encode(part["text"]) + if part["type"] == "text": + # string part => simply add the tokens + add_tokens(value_ids, 1) + elif part["type"] == "python": + # python tool call => add the tokens inside <|python_start|> and <|python_end|> + add_tokens(python_start, 1) + add_tokens(value_ids, 1) + add_tokens(python_end, 1) + elif part["type"] == "python_output": + # python output => add the tokens inside <|output_start|> and <|output_end|> + # none of these tokens are supervised because the tokens come from Python at test time + add_tokens(output_start, 0) + add_tokens(value_ids, 0) + add_tokens(output_end, 0) + else: + raise ValueError(f"Unknown part type: {part['type']}") + else: + raise ValueError(f"Unknown content type: {type(content)}") + add_tokens(assistant_end, 1) + + # truncate to max_tokens tokens MAX (helps prevent OOMs) + ids = ids[:max_tokens] + mask = mask[:max_tokens] + return ids, mask + + def visualize_tokenization(self, ids, mask): + """Small helper function useful in debugging: visualize the tokenization of render_conversation""" + RED = '\033[91m' + GREEN = '\033[92m' + RESET = '\033[0m' + tokens = [] + for i, (token_id, mask_val) in enumerate(zip(ids, mask)): + token_str = self.decode([token_id]) + color = GREEN if mask_val == 1 else RED + tokens.append(f"{color}{token_str}{RESET}") + return '|'.join(tokens) + + def render_for_completion(self, conversation): + """ + Used during Reinforcement Learning. In that setting, we want to + render the conversation priming the Assistant for a completion. + Unlike the Chat SFT case, we don't need to return the mask. 
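+        E.g. a conversation [user, assistant, user, assistant] is rendered as
+        the tokens of [user, assistant, user] followed by <|assistant_start|>.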
+        """
+        # We have some surgery to do: we need to pop the last message (of the Assistant)
+        conversation = copy.deepcopy(conversation) # avoid mutating the original
+        messages = conversation["messages"]
+        assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
+        messages.pop() # remove the last message (of the Assistant) inplace
+
+        # Now tokenize the conversation
+        ids, mask = self.render_conversation(conversation)
+
+        # Finally, to prime the Assistant for a completion, append the Assistant start token
+        assistant_start = self.encode_special("<|assistant_start|>")
+        ids.append(assistant_start)
+        return ids
+
+# -----------------------------------------------------------------------------
+# nanochat-specific convenience functions
+
+def get_tokenizer():
+    from nanochat.common import get_base_dir
+    base_dir = get_base_dir()
+    tokenizer_dir = os.path.join(base_dir, "tokenizer")
+    # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
+    return RustBPETokenizer.from_directory(tokenizer_dir)
+
+def get_token_bytes(device="cpu"):
+    import torch
+    from nanochat.common import get_base_dir
+    base_dir = get_base_dir()
+    tokenizer_dir = os.path.join(base_dir, "tokenizer")
+    token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
+    assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
+    with open(token_bytes_path, "rb") as f:
+        token_bytes = torch.load(f, map_location=device)
+    return token_bytes
diff --git a/nanochat/ui.html b/nanochat/ui.html
new file mode 100644
index 0000000..39e608f
--- /dev/null
+++ b/nanochat/ui.html
@@ -0,0 +1,394 @@
+[... 394 lines of HTML/CSS/JS for the nanochat web chat UI; the markup did not survive text extraction, and only the page title "NanoChat" and the header wordmark "nanochat" are recoverable ...]
+ + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ef3833a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,55 @@ +[project] +name = "nanochat" +version = "0.1.0" +description = "the minimal full-stack ChatGPT clone" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "datasets>=4.0.0", + "fastapi>=0.117.1", + "files-to-prompt>=0.6", + "numpy==1.26.4", + "psutil>=7.1.0", + "regex>=2025.9.1", + "tiktoken>=0.11.0", + "tokenizers>=0.22.0", + "torch>=2.8.0", + "uvicorn>=0.36.0", + "wandb>=0.21.3", +] + +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +# target torch to cuda 12.8 +[tool.uv.sources] +torch = [ + { index = "pytorch-cu128" }, +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[tool.maturin] +module-name = "rustbpe" +bindings = "pyo3" +python-source = "." +manifest-path = "rustbpe/Cargo.toml" + +[dependency-groups] +dev = [ + "maturin>=1.9.4", + "pytest>=8.0.0", +] + +[tool.pytest.ini_options] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", +] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] diff --git a/rustbpe/Cargo.lock b/rustbpe/Cargo.lock new file mode 100644 index 0000000..69f8754 --- /dev/null +++ b/rustbpe/Cargo.lock @@ -0,0 +1,458 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cfg-if" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + 
"rustversion", + "ryu", + "static_assertions", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "dary_heap" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "fancy-regex" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indexmap" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "log" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "memchr" +version = "2.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" + +[[package]] +name = 
"memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = 
"regex-automata" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" + +[[package]] +name = "rustbpe" +version = "0.1.0" +dependencies = [ + "ahash", + "compact_str", + "dary_heap", + "fancy-regex", + "indexmap", + "log", + "pyo3", + "pyo3-log", + "rayon", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.14.4+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wit-bindgen" +version = "0.45.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" + +[[package]] +name = "zerocopy" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/rustbpe/Cargo.toml b/rustbpe/Cargo.toml new file mode 100644 index 0000000..392a828 --- /dev/null +++ b/rustbpe/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rustbpe" +version = "0.1.0" +edition = "2024" + +[dependencies] +dary_heap = "0.3" +indexmap = "2.2" 
+fancy-regex = "0.16.1"
+log = "0.4.28"
+pyo3 = { version = "0.23.3", features = ["extension-module"] }
+pyo3-log = "0.12.4"
+ahash = "0.8.12"
+rayon = "1.11.0"
+compact_str = "0.9.0"
diff --git a/rustbpe/README.md b/rustbpe/README.md
new file mode 100644
index 0000000..c88636c
--- /dev/null
+++ b/rustbpe/README.md
@@ -0,0 +1,5 @@
+# rustbpe
+
+> The missing tiktoken training code
+
+A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for a GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then to export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here; I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs
new file mode 100644
index 0000000..b43fb6c
--- /dev/null
+++ b/rustbpe/src/lib.rs
@@ -0,0 +1,476 @@
+use std::cmp::Ordering;
+use std::collections::HashMap as StdHashMap;
+
+use dary_heap::OctonaryHeap;
+use fancy_regex::Regex;
+use pyo3::prelude::*;
+
+use ahash::{AHashMap, AHashSet};
+use compact_str::CompactString;
+use rayon::prelude::*;
+
+// Default GPT-4 style regex pattern for splitting text
+const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
+
+type Pair = (u32, u32);
+
+/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
+#[pyclass]
+pub struct Tokenizer {
+    /// Maps pairs of token IDs to their merged token ID
+    pub merges: StdHashMap<Pair, u32>,
+    /// The regex pattern used for text splitting
+    pub pattern: String,
+    /// Compiled regex for efficiency
+    compiled_pattern: Regex,
+}
+
+// ------------------------ internal helpers ------------------------
+
+#[derive(Clone, Debug)]
+struct Word {
+    ids: Vec<u32>,
+}
+
+impl Word {
+    #[inline]
+    fn new(ids: Vec<u32>) -> Self {
+        Self { ids }
+    }
+
+    #[inline]
+    fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
+        self.ids.windows(2).map(|w| (w[0], w[1]))
+    }
+
+    /// Merge all non-overlapping occurrences of pair -> new_id.
+    /// Returns a small Vec of local pair-count deltas for THIS word only:
+    /// -1 for removed pairs, +1 for newly created pairs.
+    ///
+    /// NOTE: this version deliberately avoids a HashMap in the hot loop.
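+    ///
+    /// Worked example (hypothetical ids): merging pair (2, 3) -> 256 in
+    /// ids = [1, 2, 3, 2, 3] leaves ids = [1, 256, 256], with deltas like
+    /// ((2, 3), -1) twice, ((1, 2), -1), ((1, 256), +1), ((256, 256), +1).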
+    fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
+        let (a, b) = pair;
+        let n = self.ids.len();
+        if n < 2 {
+            return Vec::new();
+        }
+
+        let mut out: Vec<u32> = Vec::with_capacity(n);
+        let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
+
+        let mut i = 0;
+        while i < n {
+            if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
+                let left = out.last().copied();
+                let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
+
+                // remove old pairs
+                if let Some(x) = left {
+                    deltas.push(((x, a), -1));
+                    deltas.push(((x, new_id), 1));
+                }
+                deltas.push(((a, b), -1));
+                if let Some(y) = right {
+                    deltas.push(((b, y), -1));
+                    deltas.push(((new_id, y), 1));
+                }
+
+                // write merged token
+                out.push(new_id);
+                i += 2; // skip 'a' and 'b'
+            } else {
+                out.push(self.ids[i]);
+                i += 1;
+            }
+        }
+
+        self.ids = out;
+        deltas
+    }
+}
+
+#[derive(Debug, Eq)]
+struct MergeJob {
+    pair: Pair,
+    count: u64,
+    /// set of word indices where this pair may occur and needs processing
+    pos: AHashSet<usize>,
+}
+
+impl PartialEq for MergeJob {
+    fn eq(&self, other: &Self) -> bool {
+        self.count == other.count && self.pair == other.pair
+    }
+}
+
+impl PartialOrd for MergeJob {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for MergeJob {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // Max-heap by count; tie-break to ascending pair order (deterministic)
+        if self.count != other.count {
+            self.count.cmp(&other.count)
+        } else {
+            // ascending order on the pair when counts tie
+            other.pair.cmp(&self.pair)
+        }
+    }
+}
+
+#[inline]
+fn count_pairs_parallel(
+    words: &[Word],
+    counts: &[i32],
+) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
+    words
+        .par_iter()
+        .enumerate()
+        .map(|(i, w)| {
+            let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
+            let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
+            if w.ids.len() >= 2 && counts[i] != 0 {
+                for (a, b) in w.pairs() {
+                    *local_pc.entry((a, b)).or_default() += counts[i];
+                    local_wtu.entry((a, b)).or_default().insert(i);
+                }
+            }
+            (local_pc, local_wtu)
+        })
+        .reduce(
+            || (AHashMap::new(), AHashMap::new()),
+            |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
+                for (k, v) in pc {
+                    *acc_pc.entry(k).or_default() += v;
+                }
+                for (k, s) in wtu {
+                    acc_wtu.entry(k).or_default().extend(s);
+                }
+                (acc_pc, acc_wtu)
+            },
+        )
+}
+
+// ------------------------ END helpers ------------------------
+
+impl Tokenizer {
+
+    /// Core incremental BPE training given unique words and their counts.
+    /// `words`: one entry per unique chunk (Vec of token-ids/bytes).
+    /// `counts`: same length as `words`, count per chunk.
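+    /// e.g. the texts ["ab", "ab", "cd"] would arrive here as
+    /// words = [[97, 98], [99, 100]] with counts = [2, 1] (illustrative values).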
+    fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
+        assert!(vocab_size >= 256, "vocab_size must be at least 256");
+        let num_merges = vocab_size - 256;
+        log::info!("Starting BPE training: {} merges to compute", num_merges);
+        self.merges.clear();
+
+        // ---- Initial pair_counts and where_to_update (parallel) ----
+        log::info!("Computing initial pair counts from {} unique sequences", words.len());
+        let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
+
+        // ---- Build heap ----
+        log::info!("Building heap with {} unique pairs", pair_counts.len());
+        let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
+        for (pair, pos) in where_to_update.drain() {
+            let c = *pair_counts.get(&pair).unwrap_or(&0);
+            if c > 0 {
+                heap.push(MergeJob {
+                    pair,
+                    count: c as u64,
+                    pos,
+                });
+            }
+        }
+
+        // ---- Merge loop ----
+        log::info!("Starting merge loop");
+        let mut merges_done = 0u32;
+        let mut last_log_percent = 0u32;
+
+        while merges_done < num_merges {
+            let Some(mut top) = heap.pop() else { break; };
+
+            // Lazy refresh
+            let current = *pair_counts.get(&top.pair).unwrap_or(&0);
+            if top.count != current as u64 {
+                top.count = current as u64;
+                if top.count > 0 {
+                    heap.push(top);
+                }
+                continue;
+            }
+            if top.count == 0 {
+                break;
+            }
+
+            // Record merge
+            let new_id = 256 + merges_done;
+            self.merges.insert(top.pair, new_id);
+
+            // Merge this pair in all words where it occurs
+            let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
+            for &word_idx in &top.pos {
+                // Apply merge to this word and collect pair-count deltas
+                let changes = words[word_idx].merge_pair(top.pair, new_id);
+                // Update global pair counts based on this word's count
+                for (pair, delta) in changes {
+                    let delta_total = delta * counts[word_idx];
+                    if delta_total != 0 {
+                        *pair_counts.entry(pair).or_default() += delta_total;
+                        if delta > 0 {
+                            local_pos_updates.entry(pair).or_default().insert(word_idx);
+                        }
+                    }
+                }
+            }
+
+            // Add the updated pair counts back to the heap
+            for (pair, pos) in local_pos_updates {
+                let cnt = *pair_counts.get(&pair).unwrap_or(&0);
+                if cnt > 0 {
+                    heap.push(MergeJob {
+                        pair,
+                        count: cnt as u64,
+                        pos,
+                    });
+                }
+            }
+
+            merges_done += 1;
+
+            // Log progress every 1%
+            let current_percent = (merges_done * 100) / num_merges;
+            if current_percent > last_log_percent {
+                log::info!(
+                    "Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
+                    current_percent, merges_done, num_merges, top.pair, new_id, top.count
+                );
+                last_log_percent = current_percent;
+            }
+        }
+
+        log::info!("Finished training: {} merges completed", merges_done);
+    }
+}
+
+/// Public methods for the Tokenizer class that will be exposed to Python.
+#[pymethods]
+impl Tokenizer {
+    /// Create a new Tokenizer
+    #[new]
+    pub fn new() -> Self {
+        Self {
+            merges: StdHashMap::new(),
+            pattern: String::new(),
+            compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
+        }
+    }
+
+    /// Train from a streaming iterator (parallel ingestion).
+    /// We refill a Rust Vec buffer under the GIL, then release the GIL
+    /// to do the heavy splitting and counting **in parallel** with rayon.
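+    ///
+    /// Hypothetical Python-side usage (names mirror this binding):
+    ///     tok = rustbpe.Tokenizer()
+    ///     tok.train_from_iterator(iter(["hello world"] * 1000), 512)
+    ///     ranks = tok.get_mergeable_ranks()  # feed into tiktoken.Encoding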
+    #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
+    #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
+    pub fn train_from_iterator(
+        &mut self,
+        py: pyo3::Python<'_>,
+        iterator: &pyo3::Bound<'_, pyo3::PyAny>,
+        vocab_size: u32,
+        buffer_size: usize,
+        pattern: Option<String>,
+    ) -> PyResult<()> {
+        // Use provided pattern or default to GPT-4 pattern
+        let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
+
+        // Update the stored pattern and compile it
+        self.pattern = pattern_str.clone();
+        self.compiled_pattern = Regex::new(&pattern_str)
+            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
+
+        // Prepare a true Python iterator object
+        let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
+            pyo3::Bound::from_borrowed_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
+                .into()
+        };
+
+        // Global chunk counts
+        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
+
+        // Temporary buffer we refill under the GIL
+        let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
+
+        log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
+        let mut total_sequences = 0u64;
+
+        // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
+        // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
+        let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
+            pyo3::Python::with_gil(|py| {
+                buf.clear();
+                let it = py_iter.bind(py);
+                loop {
+                    if buf.len() >= buffer_size {
+                        return Ok(false);
+                    }
+                    // next(it)
+                    let next_obj = unsafe {
+                        pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
+                    };
+                    match next_obj {
+                        Some(obj) => {
+                            let s: String = obj.extract()?;
+                            buf.push(s);
+                        }
+                        None => {
+                            if pyo3::PyErr::occurred(py) {
+                                return Err(pyo3::PyErr::fetch(py));
+                            } else {
+                                return Ok(true); // exhausted
+                            }
+                        }
+                    }
+                }
+            })
+        };
+
+        // Stream ingestion loop: refill under GIL, process without GIL (parallel)
+        loop {
+            let exhausted = refill(&mut buf)?;
+            if buf.is_empty() && exhausted {
+                break;
+            }
+
+            total_sequences += buf.len() as u64;
+
+            let pattern = self.compiled_pattern.clone();
+            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
+                buf.par_iter()
+                    .map(|s| {
+                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
+                        for mat in pattern.find_iter(s) {
+                            let piece = mat.expect("regex match failed").as_str();
+                            *m.entry(CompactString::from(piece)).or_default() += 1;
+                        }
+                        m
+                    })
+                    .reduce(
+                        || AHashMap::new(),
+                        |mut a, b| {
+                            for (k, v) in b {
+                                *a.entry(k).or_default() += v;
+                            }
+                            a
+                        },
+                    )
+            });
+
+            // Merge local into global (single-threaded)
+            for (k, v) in local {
+                *counts.entry(k).or_default() += v;
+            }
+
+            if exhausted {
+                break;
+            }
+        }
+        log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
+
+        // Materialize words & counts
+        let mut words = Vec::with_capacity(counts.len());
+        let mut cvec = Vec::with_capacity(counts.len());
+        for (chunk, c) in counts.into_iter() {
+            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+            cvec.push(c);
+        }
+
+        self.train_core_incremental(words, cvec, vocab_size);
+        Ok(())
+    }
+
+    /// Return the regex pattern
+    pub fn get_pattern(&self) -> String {
+        self.pattern.clone()
+    }
+
+    /// Return the mergeable ranks (token bytes -> token id / rank)
+    pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
+        let mut mergeable_ranks = Vec::new();
+
+        // Build vocabulary incrementally from low to high token IDs
+        let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
+
+        for (i, bytes) in token_bytes.iter().enumerate() {
+            mergeable_ranks.push((bytes.clone(), i as u32));
+        }
+
+        // Sort merges by token id (so we can reconstruct bytes progressively)
+        let mut sorted_merges: Vec<_> = self.merges.iter().collect();
+        sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
+
+        for (&pair, &merged_id) in sorted_merges {
+            let (left, right) = pair;
+            let mut merged_bytes = token_bytes[left as usize].clone();
+            merged_bytes.extend(&token_bytes[right as usize]);
+
+            if token_bytes.len() <= merged_id as usize {
+                token_bytes.resize(merged_id as usize + 1, Vec::new());
+            }
+            token_bytes[merged_id as usize] = merged_bytes.clone();
+
+            mergeable_ranks.push((merged_bytes, merged_id));
+        }
+
+        mergeable_ranks
+    }
+
+    /// Encode a string into token IDs
+    pub fn encode(&self, text: &str) -> Vec<u32> {
+        let mut all_ids = Vec::new();
+
+        // Split text using the regex pattern
+        for m in self.compiled_pattern.find_iter(text) {
+            let chunk = m.expect("regex match failed").as_str();
+
+            // Convert chunk to bytes then to u32 IDs
+            let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
+
+            // Apply merges iteratively
+            while ids.len() >= 2 {
+                // Find the best pair to merge (the one with the lowest merge rank)
+                let mut best_pair: Option<(usize, Pair, u32)> = None;
+
+                for i in 0..ids.len() - 1 {
+                    let pair: Pair = (ids[i], ids[i + 1]);
+                    if let Some(&new_id) = self.merges.get(&pair) {
+                        if best_pair.is_none() || new_id < best_pair.unwrap().2 {
+                            best_pair = Some((i, pair, new_id));
+                        }
+                    }
+                }
+
+                // If we found a pair to merge, apply it
+                if let Some((idx, _pair, new_id)) = best_pair {
+                    ids[idx] = new_id;
+                    ids.remove(idx + 1);
+                } else {
+                    // No more merges possible
+                    break;
+                }
+            }
+
+            all_ids.extend(ids);
+        }
+
+        all_ids
+    }
+}
+
+#[pymodule]
+fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    pyo3_log::init(); // forwards Rust `log` to Python's `logging`
+    m.add_class::<Tokenizer>()?;
+    Ok(())
+}
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
new file mode 100644
index 0000000..a566d49
--- /dev/null
+++ b/scripts/base_eval.py
@@ -0,0 +1,180 @@
+"""
+Evaluate the CORE metric for a given model.
+
+Run on a single GPU:
+python base_eval.py
+
+Run with torchrun on e.g. 8 GPUs:
+torchrun --nproc_per_node=8 base_eval.py
+
+The script will print the CORE metric to the console.
+"""
+import os
+import sys
+import time
+import json
+import random
+import yaml
+
+import pandas as pd
+import torch
+
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
+from nanochat.tokenizer import HuggingFaceTokenizer
+from nanochat.checkpoint_manager import load_model
+from nanochat.core_eval import evaluate_task
+
+# -----------------------------------------------------------------------------
+# nanochat-specific function dealing with I/O etc.
+
+def evaluate_model(model, tokenizer, device, max_per_task=-1):
+    """
+    Evaluate a base model on the CORE benchmark.
+    - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
+    TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
+    """
+    # Load config and task metadata
+    base_dir = get_base_dir()
+    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
+    config_path = os.path.join(eval_bundle_dir, "core.yaml")
+    data_base_path = os.path.join(eval_bundle_dir, "eval_data")
+    eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
+    with open(config_path, 'r') as f:
+        config = yaml.safe_load(f)
+    tasks = config['icl_tasks']
+    eval_metadata = pd.read_csv(eval_meta_data)
+
+    # Evaluate each task
+    results = {}
+    centered_results = {}
+    for task in tasks:
+        start_time = time.time()
+        label = task['label']
+        task_meta = {
+            'task_type': task['icl_task_type'],
+            'dataset_uri': task['dataset_uri'],
+            'num_fewshot': task['num_fewshot'][0],
+            'continuation_delimiter': task.get('continuation_delimiter', ' ')
+        }
+        print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
+
+        # Load data for this task
+        data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
+        with open(data_path, 'r') as f:
+            data = [json.loads(line.strip()) for line in f]
+
+        # shuffle the data because in many cases it appears ordered but we want
+        # the ability to only run a subset of the data for debugging purposes etc.
+        shuffle_rng = random.Random(1337)
+        shuffle_rng.shuffle(data)
+        if max_per_task > 0:
+            data = data[:max_per_task]
+
+        # run the evaluation for this task
+        accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
+
+        results[label] = accuracy
+        row = eval_metadata[eval_metadata["Eval Task"] == label]
+        random_baseline = row["Random baseline"].values[0]
+        # center the accuracy so that 0.0 = random chance and 1.0 = perfect
+        # (the baseline is given in percent, e.g. 25.0 for 4-way multiple choice => (acc - 0.25) / 0.75)
+        centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
+        centered_results[label] = centered_result
+        end_time = time.time()
+        print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
+
+    core_metric = sum(centered_results.values()) / len(centered_results)
+    out = {
+        "results": results,
+        "centered_results": centered_results,
+        "core_metric": core_metric
+    }
+    return out
+
+# -----------------------------------------------------------------------------
+# HuggingFace loading utilities and light wrappers for a model
+
+class ModelWrapper:
+    """Lightweight wrapper for a HuggingFace model"""
+    def __init__(self, model, max_seq_len=None):
+        self.model = model
+        self.max_seq_len = max_seq_len
+
+    def __call__(self, input_ids):
+        outputs = self.model(input_ids)
+        logits = outputs.logits
+        return logits
+
+def load_hf_model(hf_path: str, device):
+    print0(f"Loading model from: {hf_path}")
+    # Load the model
+    from transformers import AutoModelForCausalLM
+    model = AutoModelForCausalLM.from_pretrained(hf_path)
+    model.to(device)
+    model.eval()
+    max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
+    model = ModelWrapper(model, max_seq_len=max_seq_len)
+    # Load the tokenizer
+    tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
+    return model, tokenizer
+
+# -----------------------------------------------------------------------------
+def main():
+    assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
+
+    # distributed / precision setup
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+    # Load model and tokenizer from command line or from file system
+    if len(sys.argv) >= 2:
+        # atm assume that if a path is given, it's a huggingface model path
+        hf_path = sys.argv[1]
+        print0(f"Loading
huggingface model from: {hf_path}") + model, tokenizer = load_hf_model(hf_path, device) + model_name = hf_path # just for logging + model_slug = hf_path.replace("/", "-") # for the output csv file + else: + # load a local model from the file system + model, tokenizer, meta = load_model("base", device, phase="eval") + model_name = f"base_model (step {meta['step']})" # just for logging + model_slug = f"base_model_{meta['step']:06d}" # for the output csv file + + # Evaluate the model + with autocast_ctx: + out = evaluate_model(model, tokenizer, device) + + # Write out the results to a csv file + core_metric = None + centered_results = {} + if ddp_rank == 0: + base_dir = get_base_dir() + output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv") + os.makedirs(os.path.dirname(output_csv_path), exist_ok=True) + results = out["results"] + centered_results = out["centered_results"] + core_metric = out["core_metric"] + with open(output_csv_path, 'w') as f: + f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n") + for label in results: + f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n") + f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n") + # Print the content of the csv file to console too + print0("="*80) + print0(f"Model: {model_name}") + print0("="*80) + with open(output_csv_path, 'r') as f: + print0(f.read()) + + # Log to report + from nanochat.report import get_report + get_report().log(section="Base model evaluation", data=[ + { + "Model": model_name, + "CORE metric": core_metric, + }, + centered_results, # the full table + ]) + + compute_cleanup() + +if __name__ == "__main__": + main() diff --git a/scripts/base_loss.py b/scripts/base_loss.py new file mode 100644 index 0000000..ba3876d --- /dev/null +++ b/scripts/base_loss.py @@ -0,0 +1,78 @@ +""" +Loads a checkpoint, and: +- Evaluates the loss on a larger chunk of train/val splits +- Samples from the model + +Example run as: +torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +""" +import os +import torch +from nanochat.checkpoint_manager import load_model +from nanochat.common import compute_init, print0, compute_cleanup +from nanochat.dataloader import tokenizing_distributed_data_loader +from nanochat.tokenizer import get_token_bytes +from nanochat.loss_eval import evaluate_bpb +from nanochat.engine import Engine + +# Configuration +device_batch_size = 32 +split_tokens = 20*524288 # number of tokens to evaluate per split +model_tag = None # optional model tag for the output directory name +model_step = None # optional model step for the output directory name +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file + +# Load the base model and the tokenizer +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) +sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really + +# Set up the precision we'll run with +autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + +# Evaluate the loss on each split +tokens_per_step = device_batch_size * sequence_len * ddp_world_size +assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step" +steps = split_tokens // tokens_per_step +token_bytes = get_token_bytes(device=device) +bpb_results = {} +for split_name in ["train", "val"]: + loader = 
tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name) + with autocast_ctx: + bpb = evaluate_bpb(model, loader, steps, token_bytes) + print0(f"{split_name} bpb: {bpb:.4f}") + bpb_results[split_name] = bpb + +# Master process also samples from the model +samples = [] +if ddp_rank == 0: + prompts = [ + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", + ] + engine = Engine(model, tokenizer) + for prompt in prompts: + tokens = tokenizer(prompt, prepend="<|bos|>") + with autocast_ctx: + sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) + sample_str = tokenizer.decode(sample[0]) + print0(sample_str) + samples.append(sample_str) + +# Log to report +from nanochat.report import get_report +get_report().log(section="Base model loss", data=[ + { + "train bpb": bpb_results["train"], + "val bpb": bpb_results["val"], + }, + {f"sample {i}": sample for i, sample in enumerate(samples)}, +]) + +# Cleanup +compute_cleanup() diff --git a/scripts/base_train.py b/scripts/base_train.py new file mode 100644 index 0000000..b691ed4 --- /dev/null +++ b/scripts/base_train.py @@ -0,0 +1,339 @@ +""" +Train model. Run as: + +python base_train.py + +or distributed as: + +torchrun --nproc_per_node=8 base_train.py +""" + +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import time +import wandb +import torch + +from nanochat.gpt import GPT, GPTConfig +from nanochat.dataloader import tokenizing_distributed_data_loader +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir +from nanochat.tokenizer import get_tokenizer, get_token_bytes +from nanochat.checkpoint_manager import save_checkpoint +from nanochat.loss_eval import evaluate_bpb +from nanochat.engine import Engine +from scripts.base_eval import evaluate_model +print_banner() + +# ----------------------------------------------------------------------------- +# User settings +run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +# Model architecture +depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived +max_seq_len = 2048 # max context length +# Training horizon. Only one of these 3 will be used, in this order of precedence. +num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) +target_flops = -1.0 # calculate num_iterations to reach target_flops. 
Useful for scaling laws experiments (-1 = disable) +target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable) +# Optimization +device_batch_size = 32 # per-device batch size (set to not OOM) +total_batch_size = 524288 # total desired batch size, in #tokens +embedding_lr = 0.2 # learning rate for the embedding parameters (Adam) +unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam) +weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam) +matrix_lr = 0.02 # learning rate for the matrix parameters (Muon) +grad_clip = 1.0 # gradient clipping value (0.0 = disabled) +# Evaluation +eval_every = 250 # every how many steps to evaluate the model for val bpb +eval_tokens = 20*524288 # number of tokens to evaluate val loss on +core_metric_every = 2000 # every how many steps to evaluate the core metric +core_metric_max_per_task = 500 # examples per task in estimating the core metric +sample_every = 2000 # every how many steps to sample from the model +# Output +model_tag = "" # optionally override the model tag for the output checkpoint directory name +# now allow CLI to override the settings via the configurator lol +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # will be useful for logging +# ----------------------------------------------------------------------------- + +# Compute init +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
+autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + +# wandb logging init +use_dummy_wandb = run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config) + +# Tokenizer will be useful for evaluation, also we need the vocab size +tokenizer = get_tokenizer() +token_bytes = get_token_bytes(device=device) +vocab_size = tokenizer.get_vocab_size() +print0(f"Vocab size: {vocab_size:,}") + +# Model kwargs are derived from the desired depth of the model +num_layers = depth +model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases) +num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div) +num_kv_heads = num_heads # 1:1 MQA ratio +print0(f"num_layers: {num_layers}") +print0(f"model_dim: {model_dim}") +print0(f"num_heads: {num_heads}") +print0(f"num_kv_heads: {num_kv_heads}") + +# Optimizer / data / training length related hyperparameters +# figure out the needed gradient accumulation to reach the desired total batch size +tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +# ----------------------------------------------------------------------------- +# Initialize the Model +model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) +with torch.device("meta"): + model_config = GPTConfig(**model_config_kwargs) + model = GPT(model_config) +model.to_empty(device="cuda") +model.init_weights() +orig_model = model # original, uncompiled model, for saving raw model state_dict +model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through +num_params = sum(p.numel() for p in model.parameters()) +print0(f"Number of parameters: {num_params:,}") +num_flops_per_token = model.estimate_flops() +print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") + +# Calculate number of iterations. 
Either it is given, or from target flops, or from target data:param ratio (in that order) +assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0 +if num_iterations > 0: + print0(f"Using user-provided number of iterations: {num_iterations:,}") +elif target_flops > 0: + # calculate the number of iterations from the target flops + num_iterations = round(target_flops / (num_flops_per_token * total_batch_size)) + print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") +elif target_param_data_ratio > 0: + # calculate the number of iterations from the target param data ratio + target_tokens = target_param_data_ratio * num_params + num_iterations = target_tokens // total_batch_size + print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") +else: + raise ValueError("No training horizon specified") +total_tokens = total_batch_size * num_iterations +print0(f"Total number of training tokens: {total_tokens:,}") +print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 +print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") + +# ----------------------------------------------------------------------------- +# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) +optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +adamw_optimizer, muon_optimizer = optimizers + +# Initialize the DataLoaders for train/val +base_dir = get_base_dir() +tokens_dir = os.path.join(base_dir, "tokenized_data") +train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train") +build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val") +x, y = next(train_loader) # kick off load of the very first batch of data + +# ----------------------------------------------------------------------------- +# Set up hyperparameter schedulers + +# Learning rate scheduler +# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement) +warmup_ratio = 0.0 # ratio of iterations for LR warmup +warmdown_ratio = 0.2 # ratio of iterations for LR warmdown +final_lr_frac = 0.0 # final LR is this fraction of the initial LR +def get_lr_multiplier(it): + warmup_iters = round(warmup_ratio * num_iterations) + warmdown_iters = round(warmdown_ratio * num_iterations) + if it < warmup_iters: + return (it + 1) / warmup_iters + elif it <= num_iterations - warmdown_iters: + return 1.0 + else: + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * final_lr_frac + +# Momentum scheduler for Muon optimizer +def get_muon_momentum(it): + frac = min(it / 300, 1) + momentum = (1 - frac) * 0.85 + frac * 0.95 + return momentum + +# ----------------------------------------------------------------------------- +# Training loop +min_val_bpb = float("inf") +smooth_train_loss = 0 # EMA of training loss +ema_beta = 0.9 # EMA decay factor +total_training_time = 0 # total wall-clock time of training +# note that we run +1 steps only so that we can eval and save at the end +for step in range(num_iterations + 1): + last_step = step == num_iterations + flops_so_far = num_flops_per_token * total_batch_size * step + + # once in a while: evaluate the val bpb (all ranks participate) + if last_step or step % eval_every == 0: + model.eval() + val_loader = build_val_loader() + 
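+        # e.g. with the defaults: eval_steps = 20*524288 // (32 * 2048 * 8) = 20
+        # (assuming an 8-GPU run; the divisor is tokens per forward pass across all ranks)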
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size) + with autocast_ctx: + val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") + if val_bpb < min_val_bpb: + min_val_bpb = val_bpb + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + }) + model.train() + + # once in a while: estimate the CORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + if last_step or (step > 0 and step % core_metric_every == 0): + model.eval() + with autocast_ctx: + results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task) + print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "core_metric": results["core_metric"], + "centered_results": results["centered_results"], + }) + model.train() + + # once in a while: sample from the model (only on master process) + # use the original uncompiled model because the inputs keep changing shape + if master_process and (last_step or (step > 0 and step % sample_every == 0)): + model.eval() + prompts = [ + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", + ] + engine = Engine(model, tokenizer) + for prompt in prompts: + tokens = tokenizer(prompt, prepend="<|bos|>") + with autocast_ctx: + sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) + print0(tokenizer.decode(sample[0])) + model.train() + + # save checkpoint at the end of the run (only on master process) + if master_process and last_step: + output_dirname = model_tag if model_tag else f"d{depth}" # e.g. 
d12
+        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+        save_checkpoint(
+            checkpoint_dir,
+            step,
+            orig_model.state_dict(),
+            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
+            {
+                "step": step,
+                "val_bpb": val_bpb, # loss at last step
+                "model_config": model_config_kwargs,
+                "user_config": user_config, # inputs to the training script
+                "device_batch_size": device_batch_size,
+                "max_seq_len": max_seq_len,
+            }
+        )
+
+    if last_step:
+        break
+
+    # -------------------------------------------------------------------------
+    # single training step
+    # evaluate the gradient
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for micro_step in range(grad_accum_steps):
+        with autocast_ctx:
+            loss = model(x, y)
+        train_loss = loss.detach() # for logging
+        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
+        loss.backward()
+        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+    # gradient clipping (TODO possibly experiment with)
+    if grad_clip > 0.0:
+        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
+    # step the optimizers
+    lrm = get_lr_multiplier(step)
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * lrm
+    muon_momentum = get_muon_momentum(step)
+    for group in muon_optimizer.param_groups:
+        group["momentum"] = muon_momentum
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+    torch.cuda.synchronize()
+    t1 = time.time()
+    dt = t1 - t0
+    # -------------------------------------------------------------------------
+
+    # logging
+    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
+    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
+    pct_done = 100 * step / num_iterations
+    tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+    flops_per_sec = num_flops_per_token * total_batch_size / dt
+    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
+    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    if step > 10:
+        total_training_time += dt # only count the time after the first 10 steps
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    if step % 100 == 0:
+        wandb_run.log({
+            "step": step,
+            "total_training_flops": flops_so_far,
+            "total_training_time": total_training_time,
+            "train/loss": debiased_smooth_loss,
+            "train/lrm": lrm,
+            "train/dt": dt,
+            "train/tok_per_sec": tok_per_sec,
+            "train/mfu": mfu,
+        })
+
+# print a few more stats
+print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Total training time: {total_training_time/60:.2f}m")
+print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
+
+# Log to report
+from nanochat.report import get_report
+get_report().log(section="Base model training", data=[
+    user_config, # CLI args
+    { # stats about the training setup
+        "Number of parameters": num_params,
+        "Number of FLOPs per token": f"{num_flops_per_token:e}",
+        "Calculated number of iterations": num_iterations,
+        "Number of training tokens": total_tokens,
+        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
+        "DDP world size": ddp_world_size,
+        "warmup_ratio": 
warmup_ratio, + "warmdown_ratio": warmdown_ratio, + "final_lr_frac": final_lr_frac, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + "Final validation bpb": val_bpb, + "CORE metric estimate": results["core_metric"], + "MFU %": f"{mfu:.2f}%", + "Total training flops": f"{flops_so_far:e}", + "Total training time": f"{total_training_time/60:.2f}m", + "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB", + } +]) + +# cleanup +wandb_run.finish() # wandb run finish +compute_cleanup() diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py new file mode 100644 index 0000000..3a38147 --- /dev/null +++ b/scripts/chat_cli.py @@ -0,0 +1,99 @@ +""" +New and upgraded chat mode because a lot of the code has changed since the last one. + +Intended to be run single GPU only atm: +python -m scripts.chat_cli -i mid +""" +import argparse +import torch +from nanochat.common import compute_init +from nanochat.engine import Engine +from nanochat.checkpoint_manager import load_model + +parser = argparse.ArgumentParser(description='Chat with the model') +parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") +parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') +parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') +parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') +parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') +parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') +args = parser.parse_args() + +# Init the model and tokenizer +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) + +# Special tokens for the chat state machine +bos = tokenizer.get_bos_token_id() +user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>") +assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>") + +# Create Engine for efficient generation +engine = Engine(model, tokenizer) + +print("\nNanoChat Interactive Mode") +print("-" * 50) +print("Type 'quit' or 'exit' to end the conversation") +print("Type 'clear' to start a new conversation") +print("-" * 50) + +conversation_tokens = [bos] + +while True: + + if args.prompt: + # Get the prompt from the launch command + user_input = args.prompt + else: + # Get the prompt interactively from the console + try: + user_input = input("\nUser: ").strip() + except (EOFError, KeyboardInterrupt): + print("\nGoodbye!") + break + + # Handle special commands + if user_input.lower() in ['quit', 'exit']: + print("Goodbye!") + break + + if user_input.lower() == 'clear': + conversation_tokens = [bos] + print("Conversation cleared.") + continue + + if not user_input: + continue + + # Add User message to the conversation + conversation_tokens.append(user_start) + conversation_tokens.extend(tokenizer.encode(user_input)) + conversation_tokens.append(user_end) + + # Kick off the assistant + conversation_tokens.append(assistant_start) + generate_kwargs = { + "num_samples": 1, + "max_tokens": 256, + "temperature": args.temperature, + "top_k": args.top_k, + } + 
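# For reference, the token stream at this point is laid out as (one user turn shown):
+    #   <|bos|> <|user_start|> ...user tokens... <|user_end|> <|assistant_start|>
+    # and the engine continues sampling from the trailing <|assistant_start|> below.
+    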
response_tokens = []
+    print("\nAssistant: ", end="", flush=True)
+    with autocast_ctx:
+        for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
+            token = token_column[0] # pop the batch dimension (num_samples=1)
+            response_tokens.append(token)
+            token_text = tokenizer.decode([token])
+            print(token_text, end="", flush=True)
+    print()
+    # we have to ensure that the assistant end token is the last token
+    # so even if generation ends due to max tokens, we have to append it to the end
+    if response_tokens[-1] != assistant_end:
+        response_tokens.append(assistant_end)
+    conversation_tokens.extend(response_tokens)
+
+    # In the prompt mode, we only want a single response and exit
+    if args.prompt:
+        break
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
new file mode 100644
index 0000000..df6a01a
--- /dev/null
+++ b/scripts/chat_eval.py
@@ -0,0 +1,251 @@
+"""
+Evaluate the Chat model.
+All the generic code lives here, and all the evaluation-specific
+code lives in the nanochat directory and is imported from here.
+
+Example runs:
+python -m scripts.chat_eval -a ARC-Easy
+torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
+"""
+
+import argparse
+from functools import partial
+
+import torch
+import torch.distributed as dist
+
+from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
+from nanochat.checkpoint_manager import load_model
+from nanochat.engine import Engine
+
+from tasks.humaneval import HumanEval
+from tasks.mmlu import MMLU
+from tasks.arc import ARC
+from tasks.gsm8k import GSM8K
+
+# -----------------------------------------------------------------------------
+# Generative evaluation loop (we go one problem at a time, sample, evaluate)
+
+def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
+
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    device = model.get_device()
+
+    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
+
+    # Run the evaluation
+    num_passed, total = 0, 0
+    for i in range(ddp_rank, num_problems, ddp_world_size):
+        conversation = task_object[i]
+
+        # Tokenize the prompt
+        encoded_prompt = tokenizer.render_for_completion(conversation)
+        # Get the completions
+        results, _ = engine.generate_batch(
+            encoded_prompt,
+            num_samples=num_samples,
+            max_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+        )
+        # Decode the completions as text
+        prefix_length = len(encoded_prompt)
+        completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
+        # Evaluate success criteria
+        outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
+        passed = any(outcomes)
+
+        # Keep stats
+        total += 1
+        num_passed += int(passed)
+
+        # Logging (overwrite the same line in the console)
+        print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
+
+    # Finish the in-place progress line with a newline before final summary
+    print()
+
+    # Aggregate results across all ranks
+    if ddp:
+        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
+        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
+        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
+        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
+        num_passed = num_passed_tensor.item()
+        total = total_tensor.item()
+
+    print0("=" * 50)
+    print0(f"Final: 
{num_passed}/{total} ({100*num_passed/total:.2f}%)")
+
+    # Return the accuracy
+    return num_passed/total
+
+# -----------------------------------------------------------------------------
+# Categorical evaluation loop
+# A lot easier because we don't have to sample. Therefore, we can actually go
+# batches at a time and just check the logits for the correct answer choices.
+
+def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
+
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    device = model.get_device()
+    bos = tokenizer.get_bos_token_id() # using BOS as the pad token is ok, these positions are ignored
+
+    # We'll process batches of independent problems at a time because there is no sampling needed
+    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
+    ceil_div = lambda x, y: -(-x // y)
+    num_batches = ceil_div(num_problems, batch_size)
+
+    # Run the evaluation
+    letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
+    num_passed, total = 0, 0
+    for i in range(ddp_rank, num_batches, ddp_world_size):
+        i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
+
+        # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
+        conversations = [task_object[ii] for ii in range(i0, i1)]
+        prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
+        max_length = max(len(ids) for ids in prompt_ids)
+        answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
+        padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
+        prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
+
+        # Get the logits for the whole batch of conversations in parallel (efficiency win here)
+        with torch.no_grad():
+            logits = model(prompt_ids) # (B, T, V)
+
+        # Focus the prediction at the answer position on just the letters corresponding to the choices.
+        # Note that this helps the evaluation a lot because it narrows the focus to only the available letters.
+        # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
+        # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
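+        # Concretely (hypothetical numbers): for letters ["A", "B", "C", "D"], each of which
+        # must encode to a single token id, we slice the logits at the final prompt position
+        # down to those four ids and take the argmax as the predicted letter.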
+        for idx, conversation in enumerate(conversations):
+            # get the token ids of all the available letters of this problem
+            letters = conversation['letters']
+            letter_ids = []
+            for letter in letters:
+                if letter not in letter_to_id_cache:
+                    encoded_letter = tokenizer.encode(letter)
+                    assert len(encoded_letter) == 1, "Each letter must be a single token"
+                    letter_to_id_cache[letter] = encoded_letter[0]
+                letter_ids.append(letter_to_id_cache[letter])
+            # focus logits just down to the answer position and the available letters of the answer
+            answer_pos = answer_time_positions[idx]
+            focus_logits = logits[idx, answer_pos, letter_ids]
+            # get the argmax letter (the predicted answer)
+            argmax_letter_id = focus_logits.argmax(dim=-1).item()
+            predicted_letter = letters[argmax_letter_id]
+            # evaluate the outcome
+            outcome = task_object.evaluate(conversation, predicted_letter)
+            num_passed += int(outcome)
+            total += 1
+
+    # Aggregate results across all ranks
+    if ddp:
+        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
+        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
+        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
+        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
+        num_passed = num_passed_tensor.item()
+        total = total_tensor.item()
+
+    average = num_passed/total
+    print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
+    return average
+
+# -----------------------------------------------------------------------------
+
+def run_chat_eval(task_name, model, tokenizer, engine,
+                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
+                  max_problems=None):
+    # Create the evaluation object
+    task_module = {
+        'HumanEval': HumanEval,
+        'MMLU': partial(MMLU, subset="all", split="test"),
+        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
+        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
+        'GSM8K': partial(GSM8K, subset="main", split="test"),
+    }[task_name]
+    task_object = task_module()
+    # Run the evaluation
+    if task_object.eval_type == 'generative':
+        acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
+    elif task_object.eval_type == 'categorical':
+        acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
+    else:
+        raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
+    return acc
+
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
+    parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. 
Use | to split multiple tasks.") + parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) + parser.add_argument('-t', '--temperature', type=float, default=0.0) + parser.add_argument('-m', '--max-new-tokens', type=int, default=512) + parser.add_argument('-n', '--num-samples', type=int, default=1) + parser.add_argument('-k', '--top-k', type=int, default=50) + parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation') + parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') + parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') + parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') + args = parser.parse_args() + + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() + ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 + autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) + + model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) + engine = Engine(model, tokenizer) + + # Get the tasks to evaluate on + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval'] + baseline_accuracies = { + 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% + 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% + 'MMLU': 0.25, # multiple choice 1 of 4 => 25% + 'GSM8K': 0.0, # open-ended => 0% + 'HumanEval': 0.0, # open-ended => 0% + } + task_names = all_tasks if args.task_name is None else args.task_name.split('|') + + # Run all the task evaluations sequentially + results = {} + for task_name in task_names: + with autocast_ctx: + acc = run_chat_eval( + task_name, + model, tokenizer, engine, + batch_size=args.batch_size, + num_samples=args.num_samples, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + max_problems=args.max_problems, + ) + results[task_name] = acc + print0(f"{task_name} accuracy: {100 * acc:.2f}%") + + # Log to report + from nanochat.report import get_report + all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks) + # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy) + # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance) + chatcore_metric_dict = {} + if all_tasks_were_evaluated: + centered_mean = 0 + for task_name, acc in results.items(): + baseline_acc = baseline_accuracies.get(task_name, 0.0) + centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc) + centered_mean += centered_acc + chatcore_metric = centered_mean / len(results) + chatcore_metric_dict = {"ChatCORE metric": chatcore_metric} + get_report().log(section="Chat evaluation " + args.source, data=[ + vars(args), # CLI args + results, + chatcore_metric_dict, + ]) + + compute_cleanup() diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py new file mode 100644 index 0000000..af70bda --- /dev/null +++ b/scripts/chat_rl.py @@ -0,0 +1,331 @@ +""" +Reinforcement learning on GSM8K via "GRPO". + +I put GRPO in quotes because we actually end up with something a lot +simpler and more similar to just REINFORCE: + +1) Delete trust region, so there is no KL regularization to a reference model +2) We are on policy, so there's no need for PPO ratio+clip. +3) We use GAPO style normalization that is token-level, not sequence-level. 
+4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage. + +1 GPU: +python -m scripts.chat_rl + +8 GPUs: +torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default +""" + +import os +import itertools +import re +import wandb +import torch +import torch.distributed as dist + +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb +from nanochat.checkpoint_manager import save_checkpoint, load_model +from nanochat.engine import Engine +from tasks.gsm8k import GSM8K + +# RL hyperparameters +run = "dummy" # wandb run name +source = "sft" # mid|sft +dtype = "bfloat16" +device_batch_size = 8 # no forward pass will go above this to not OOM +examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!) +num_samples = 16 # number of samples per example (/question) +max_new_tokens = 256 +temperature = 1.0 +top_k = 50 # TODO: try None? +unembedding_lr = 0.004 +embedding_lr = 0.2 +matrix_lr = 0.02 +weight_decay = 0.0 +init_lr_frac = 0.05 +num_epochs = 1 # how many epochs of gsm8k to train on +save_every = 60 # every how many steps to save the model +eval_every = 60 # every how many steps to evaluate the model for val pass@k +eval_examples = 400 # number of examples used for evaluating pass@k +# now allow CLI to override the settings via the configurator lol +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # will be useful for logging +# ----------------------------------------------------------------------------- + +# Init compute/precision +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. +dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 +autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) + +# wandb logging init +use_dummy_wandb = run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config) + +# Init model and tokenizer +model, tokenizer, meta = load_model(source, device, phase="eval") +engine = Engine(model, tokenizer) # for sampling rollouts + +# ----------------------------------------------------------------------------- +# Rollout / sampling generator loop that yields batches of examples for training + +train_task = GSM8K(subset="main", split="train") +val_task = GSM8K(subset="main", split="test") +num_steps = (len(train_task) // examples_per_step) * num_epochs +print0(f"Calculated number of steps: {num_steps}") + +@torch.no_grad() +def get_batch(): + assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss. + rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data + for example_idx in itertools.cycle(rank_indices): + + # First get the full conversation of both user and assistant messages + conversation = train_task[example_idx] + + # Tokenize the conversation, deleting the last Assistant message and priming the Assistant for a completion instead + # (i.e. 
keep the <|assistant_start|>, but delete everything after it) + tokens = tokenizer.render_for_completion(conversation) + prefix_length = len(tokens) + + # Generate num_samples samples using batched generation, use loop to avoid OOMs + model.eval() # ensure the model is in eval mode + generated_token_sequences = [] + masks = [] + num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs + for sampling_step in range(num_sampling_steps): + seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 + with autocast_ctx: + generated_token_sequences_batch, masks_batch = engine.generate_batch( + tokens, + num_samples=device_batch_size, + max_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + seed=seed, # must make sure to change the seed for each sampling step + ) + generated_token_sequences.extend(generated_token_sequences_batch) + masks.extend(masks_batch) + + # Calculate the rewards for each sample + rewards = [] + for sample_tokens in generated_token_sequences: + # Get just the generated tokens (after the prompt) + generated_tokens = sample_tokens[prefix_length:] + # Decode the generated response + generated_text = tokenizer.decode(generated_tokens) + # Calculate the reward + reward = train_task.reward(conversation, generated_text) + rewards.append(reward) + + # Pad the sequences so that their lengths (in time) match + max_length = max(len(seq) for seq in generated_token_sequences) + padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences] + padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks] + # Stack up the sequences and masks into PyTorch tensors + ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device) + mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device) + # Generate autoregressive inputs and targets to the Transformer + inputs = ids[:, :-1] + targets = ids[:, 1:].clone() # clone to avoid in-place modification: + targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index + # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens. + # So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens. + rewards = torch.tensor(rewards, dtype=torch.float, device=device) + # Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma) + mu = rewards.mean() + advantages = rewards - mu + # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats + yield generated_token_sequences, inputs, targets, rewards, advantages + +# ----------------------------------------------------------------------------- +# Simple evaluation loop for GSM8K pass@k +def run_gsm8k_eval(task, tokenizer, engine, + max_examples=None, + num_samples=1, + max_completion_tokens=256, + temperature=0.0, + top_k=50 +): + """ + Evaluates GSM8K task and returns a list of records of evaluation outcomes. + In a distributed setting, all ranks cooperate but this function will NOT + do the reduction across ranks. This is the responsibility of the caller. + Because the evaluation can take a while, this function will yield records one by one. 
+ """ + max_examples = min(max_examples, len(task)) if max_examples is not None else len(task) + for idx in range(ddp_rank, max_examples, ddp_world_size): + conversation = task[idx] + tokens = tokenizer.render_for_completion(conversation) + prefix_length = len(tokens) + # Generate k samples using batched generation inside the Engine + assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... + generated_token_sequences, masks = engine.generate_batch( + tokens, + num_samples=num_samples, + max_tokens=max_completion_tokens, + temperature=temperature, + top_k=top_k + ) + # Check each sample for correctness + outcomes = [] + for sample_tokens in generated_token_sequences: + generated_tokens = sample_tokens[prefix_length:] + generated_text = tokenizer.decode(generated_tokens) + is_correct = task.evaluate(conversation, generated_text) + outcomes.append({ + "is_correct": is_correct + }) + # A bit bloated because I wanted to do more complex logging at one point. + record = { + "idx": idx, + "outcomes": outcomes, + } + yield record + +# ----------------------------------------------------------------------------- +# Training loop + +# Init the optimizer +optimizers = model.setup_optimizers( + unembedding_lr=unembedding_lr, + embedding_lr=embedding_lr, + matrix_lr=matrix_lr, + weight_decay=weight_decay, +) + +# Set the initial learning rate as a fraction of the base learning rate +for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["lr"] * init_lr_frac + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + +# Learning rate scheduler: simple rampdown to zero over num_steps +def get_lr_multiplier(it): + lrm = 1.0 - it / num_steps + return lrm + +# Calculate the number of examples each rank handles to achive the desired examples_per_step +print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step +assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks" +examples_per_rank = examples_per_step // ddp_world_size # per GPU +print0(f"Calculated examples per rank: {examples_per_rank}") + +# Kick off the training loop +batch_iterator = get_batch() +for step in range(num_steps): + + # Evaluate the model once in a while and log to wandb + if step % eval_every == 0: + model.eval() + passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size + with autocast_ctx: + records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0) + records = list(records_iter) # collect all records + for k in range(1, device_batch_size + 1): + passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) + num_records = torch.tensor(len(records), dtype=torch.long, device=device) + if ddp: + dist.all_reduce(num_records, op=dist.ReduceOp.SUM) + dist.all_reduce(passk, op=dist.ReduceOp.SUM) + passk = passk / num_records.item() # normalize by the total number of records + print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)] + print0(f"Step {step} | {', '.join(print_passk)}") + log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)} + wandb_run.log({ + "step": step, + **log_passk, + }) + + # Forward/Backward on rollouts over multiple examples in the dataset + rewards_list = [] + sequence_lengths = [] + for example_step in 
range(examples_per_rank): + # Get one batch corresponding to one example in the training dataset + sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator) + # Evaluate the loss and gradients + model.train() # ensure the model is in train mode + # We need one more loop because we can never exceed the device_batch_size + assert inputs_all.size(0) % device_batch_size == 0 + num_passes = inputs_all.size(0) // device_batch_size + for pass_idx in range(num_passes): + # Pluck out the batch for this pass + b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size + inputs = inputs_all[b0:b1] + targets = targets_all[b0:b1] + rewards = rewards_all[b0:b1] + advantages = advantages_all[b0:b1] + # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate + with autocast_ctx: + logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) + # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. + pg_obj = (logp * advantages.unsqueeze(-1)).sum() + # normalize by the number of valid tokens, number of passes, and examples_per_rank + num_valid = (targets >= 0).sum().clamp(min=1) + pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank) + # Note, there is no need to add PPO ratio+clip because we are on policy + # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize) + loss = -pg_obj + loss.backward() + print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}") + # For logging + rewards_list.append(rewards_all.mean().item()) + sequence_lengths.extend(len(seq) for seq in sequences_all) + + # A bunch of logging for how the rollouts went this step + mean_reward = sum(rewards_list) / len(rewards_list) + mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths) + if ddp: # aggregate across ranks + mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device) + mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device) + dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG) + dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG) + mean_reward = mean_reward_tensor.item() + mean_sequence_length = mean_sequence_length_tensor.item() + print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}") + wandb_run.log({ + "step": step, + "reward": mean_reward, + "sequence_length": mean_sequence_length, + }) + + # Update the model parameters + lrm = get_lr_multiplier(step) + for opt in optimizers: # first set the learning rate + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lrm + for opt in optimizers: # then step the optimizers + opt.step() + model.zero_grad(set_to_none=True) + wandb_run.log({ + "step": step, + "lrm": lrm, + }) + + # Master process saves the model once in a while. Skip first step. Save last step. 
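+    # Note that only the model weights get checkpointed below; optimizer state is not saved.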
+    if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
+        base_dir = get_base_dir()
+        depth = model.config.n_layer
+        model_tag = f"d{depth}" # base the model tag on the depth of the base model
+        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
+        model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
+        save_checkpoint(
+            checkpoint_dir,
+            step,
+            model.state_dict(),
+            None, # note: we don't bother to save the optimizer state
+            {
+                "model_config": model_config_kwargs,
+            }
+        )
+        print(f"βœ… Saved model checkpoint to {checkpoint_dir}")
+
+# Log to report
+from nanochat.report import get_report
+get_report().log(section="Chat RL", data=[
+    user_config, # CLI args
+])
+
+wandb_run.finish() # wandb run finish
+compute_cleanup()
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
new file mode 100644
index 0000000..8389deb
--- /dev/null
+++ b/scripts/chat_sft.py
@@ -0,0 +1,281 @@
+"""
+Finetune a base model to be a chat model.
+Run on one GPU e.g. for debugging:
+
+python -m scripts.chat_sft
+
+Or torchrun for training:
+
+torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
+"""
+
+import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import copy
+
+import wandb
+import torch
+import torch.distributed as dist
+
+from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
+from nanochat.checkpoint_manager import load_model
+from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.engine import Engine
+from scripts.chat_eval import run_chat_eval
+
+from tasks.common import TaskMixture, TaskSequence
+from tasks.mmlu import MMLU
+from tasks.arc import ARC
+from tasks.gsm8k import GSM8K
+from tasks.humaneval import HumanEval
+from tasks.smoltalk import SmolTalk
+
+# -----------------------------------------------------------------------------
+# SFT Hyperparameters
+run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+# input model options
+source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
+model_tag = None # model tag to load the model from (base model or midtrained model)
+step = None # step to load the model from (base model or midtrained model)
+# compute/precision
+dtype = "bfloat16"
+device_batch_size = 4 # max to avoid OOM
+# optimization
+num_epochs = 1
+max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
+target_examples_per_step = 32
+unembedding_lr = 0.004
+embedding_lr = 0.2
+matrix_lr = 0.02
+weight_decay = 0.0
+init_lr_frac = 0.02
+# evaluation and logging thereof
+eval_every = 100
+eval_steps = 100
+eval_metrics_every = 200
+# now allow CLI to override the settings via the configurator lol
+config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
+user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
+# -----------------------------------------------------------------------------
+
+# Compute init
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+master_process = ddp_rank == 0
+dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
+autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+
+# wandb logging init
+use_dummy_wandb = run == 
"dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) + +# Load the model and tokenizer +model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) +orig_model = model # original, uncompiled model +# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs +engine = Engine(model, tokenizer) # will be used for inline model evaluation only + +# ----------------------------------------------------------------------------- +# Task data mixture we'll train on + +train_ds = TaskMixture([ + ARC(subset="ARC-Easy", split="train"), # 2.3K rows + ARC(subset="ARC-Challenge", split="train"), # 1.1K rows + GSM8K(subset="main", split="train"), # 8K rows + SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk +]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows +val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) + +# ----------------------------------------------------------------------------- +# DataLoader + +def sft_data_generator(dataset, batch_size): + pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss + # prepares a list of tokenized conversations into a batch and yields + def collate_and_yield(batch): + nrows = len(batch) + ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 + inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) + targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index + for i, (ids, mask) in enumerate(batch): + n = len(ids) + ids_tensor = torch.tensor(ids, dtype=torch.long) + inputs[i, :n-1] = ids_tensor[:-1] + # recall -1 is the ignore index, so mask out targets where mask is 0 + row_targets = ids_tensor[1:] + # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok + mask_tensor = torch.tensor(mask[1:], dtype=torch.long) + row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 + targets[i, :n-1] = row_targets + inputs = inputs.to(device) # move to device + targets = targets.to(device) + return inputs, targets + # iterates over the dataset in epochs, tokenizes + batch = [] + while True: + for i in range(ddp_rank, len(dataset), ddp_world_size): + doc = dataset[i] + ids, mask = tokenizer.render_conversation(doc) + batch.append((ids, mask)) + if len(batch) == batch_size: + yield collate_and_yield(batch) + batch = [] + +examples_per_step = device_batch_size * ddp_world_size +print0(f"Target examples per step: {target_examples_per_step}") +print0(f"Device batch size: {device_batch_size}") +print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") +assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" +grad_accum_steps = target_examples_per_step // examples_per_step +print0(f"=> Setting grad accum steps: {grad_accum_steps}") + +num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs +if max_iterations >= 0 and num_iterations > max_iterations: + print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}") + num_iterations = max_iterations +train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) +build_val_loader = lambda: sft_data_generator(val_ds, 
batch_size=device_batch_size)
+
+# -----------------------------------------------------------------------------
+# Initialize the Optimizer
+
+optimizers = model.setup_optimizers(
+    unembedding_lr=unembedding_lr,
+    embedding_lr=embedding_lr,
+    matrix_lr=matrix_lr,
+    weight_decay=weight_decay,
+)
+# Set the initial learning rate as a fraction of the base learning rate
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["lr"] = group["lr"] * init_lr_frac
+        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+
+# -----------------------------------------------------------------------------
+# Training loop
+
+# Learning rate scheduler
+def get_lr_multiplier(it):
+    lrm = 1.0 - it / num_iterations
+    return lrm
+
+# Go!
+step = 0
+train_iter = iter(train_loader)
+for step in range(num_iterations):
+    last_step = step == num_iterations - 1
+
+    # evaluate the validation loss
+    if last_step or step % eval_every == 0:
+        model.eval()
+        val_iter = iter(build_val_loader())
+        losses = []
+        for _ in range(eval_steps):
+            val_inputs, val_targets = next(val_iter)
+            with torch.no_grad(), autocast_ctx:
+                loss = model(val_inputs, val_targets)
+            losses.append(loss)
+        val_loss = torch.stack(losses).mean() # average over eval_steps
+        if ddp:
+            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
+        val_loss = val_loss.item()
+        print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
+        wandb_run.log({
+            "step": step,
+            "val_loss": val_loss,
+        })
+        model.train()
+
+    # evaluate a few accuracy metrics (MMLU, ARC-Easy, GSM8K, HumanEval)
+    if last_step or (step > 0 and step % eval_metrics_every == 0):
+        model.eval()
+        metrics = {}
+        with torch.no_grad(), autocast_ctx:
+            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
+            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
+            metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
+        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
+        print0(f"Step {step:05d} | {metrics_str}")
+        wandb_run.log({
+            "step": step,
+            **metrics,
+        })
+        model.train()
+
+    if last_step:
+        break
+
+    # evaluate the gradient
+    num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
+    for micro_step in range(grad_accum_steps):
+        train_inputs, train_targets = next(train_iter)
+        with autocast_ctx:
+            loss = model(train_inputs, train_targets)
+        train_loss = loss.detach() # for logging
+        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
+        loss.backward() # accumulate the gradient
+        num_tokens += (train_targets >= 0).sum()
+    if ddp:
+        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
+
+    # learning rate scheduler
+    lrm = get_lr_multiplier(step)
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * lrm
+
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+    # logging
+    train_loss_item = train_loss.item()
+    num_tokens_item = num_tokens.item()
+    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f} | lrm: {lrm:.6f} | num_tokens: {num_tokens_item:,}")
+    wandb_run.log({
+        "step": 
step,
+        "lrm": lrm,
+        "train_loss": train_loss_item,
+        "num_tokens": num_tokens_item,
+    })
+    step += 1
+
+# Save the model at the end of the run
+if master_process:
+    base_dir = get_base_dir()
+    depth = model.config.n_layer
+    model_tag = f"d{depth}" # base the model tag on the depth of the base model
+    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
+    model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
+    save_checkpoint(
+        checkpoint_dir,
+        step,
+        model.state_dict(),
+        None, # note: we don't bother to save the optimizer state
+        {
+            "step": step,
+            "val_loss": val_loss,
+            **metrics,
+            "model_config": model_config_kwargs,
+        }
+    )
+    print(f"βœ… Saved model checkpoint to {checkpoint_dir}")
+
+# Log to report
+from nanochat.report import get_report
+get_report().log(section="Chat SFT", data=[
+    user_config, # CLI args
+    {
+        "Training rows": len(train_ds),
+        "Number of iterations": num_iterations,
+        "Training loss": train_loss_item,
+        "Validation loss": val_loss,
+    },
+])
+
+# Cleanup
+wandb_run.finish()
+compute_cleanup()
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
new file mode 100644
index 0000000..1a4cfe2
--- /dev/null
+++ b/scripts/chat_web.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+"""
+Unified web chat server - serves both UI and API from a single FastAPI instance.
+Run with: python -m scripts.chat_web
+Then open http://localhost:8000 in your browser.
+"""
+
+import argparse
+import json
+import os
+import torch
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
+from pydantic import BaseModel
+from typing import List, Optional, AsyncGenerator
+
+from nanochat.common import compute_init
+from nanochat.checkpoint_manager import load_model
+from nanochat.engine import Engine
+
+parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
+parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
+parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
+parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
+parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
+parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
+parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+args = parser.parse_args()
+
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    top_k: Optional[int] = None
+    stream: Optional[bool] = True
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Load model on startup."""
+    print("Loading nanochat model...")
+    app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
+    app.state.engine = Engine(app.state.model, 
app.state.tokenizer)
+    print(f"Server ready at http://localhost:{args.port}")
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/")
+async def root():
+    """Serve the chat UI."""
+    ui_html_path = os.path.join("nanochat", "ui.html")
+    with open(ui_html_path, "r") as f:
+        html_content = f.read()
+    # Replace the API_URL to use the same origin
+    html_content = html_content.replace(
+        "const API_URL = `http://${window.location.hostname}:8000`;",
+        "const API_URL = '';"
+    )
+    return HTMLResponse(content=html_content)
+
+
+@app.get("/logo.svg")
+async def logo():
+    """Serve the NanoChat logo for favicon and header."""
+    logo_path = os.path.join("nanochat", "logo.svg")
+    return FileResponse(logo_path, media_type="image/svg+xml")
+
+async def generate_stream(
+    engine,
+    tokenizer,
+    tokens,
+    temperature=None,
+    max_new_tokens=None,
+    top_k=None
+) -> AsyncGenerator[str, None]:
+    """Generate assistant response with streaming."""
+    temperature = temperature if temperature is not None else args.temperature
+    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
+    top_k = top_k if top_k is not None else args.top_k
+
+    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+    bos = tokenizer.get_bos_token_id()
+
+    with autocast_ctx:
+        for token_column, token_masks in engine.generate(
+            tokens,
+            num_samples=1,
+            max_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k
+        ):
+            token = token_column[0]
+
+            if token == assistant_end or token == bos:
+                break
+
+            token_text = tokenizer.decode([token])
+            yield f"data: {json.dumps({'token': token_text})}\n\n"
+
+    yield f"data: {json.dumps({'done': True})}\n\n"
+
+@app.post("/chat/completions")
+async def chat_completions(request: ChatRequest):
+    """Chat completion endpoint with streaming."""
+    engine = app.state.engine
+    tokenizer = app.state.tokenizer
+
+    # Build conversation tokens
+    bos = tokenizer.get_bos_token_id()
+    user_start = tokenizer.encode_special("<|user_start|>")
+    user_end = tokenizer.encode_special("<|user_end|>")
+    assistant_start = tokenizer.encode_special("<|assistant_start|>")
+    assistant_end = tokenizer.encode_special("<|assistant_end|>")
+
+    conversation_tokens = [bos]
+    for message in request.messages:
+        if message.role == "user":
+            conversation_tokens.append(user_start)
+            conversation_tokens.extend(tokenizer.encode(message.content))
+            conversation_tokens.append(user_end)
+        elif message.role == "assistant":
+            conversation_tokens.append(assistant_start)
+            conversation_tokens.extend(tokenizer.encode(message.content))
+            conversation_tokens.append(assistant_end)
+
+    conversation_tokens.append(assistant_start)
+
+    if request.stream:
+        return StreamingResponse(
+            generate_stream(
+                engine,
+                tokenizer,
+                conversation_tokens,
+                temperature=request.temperature,
+                max_new_tokens=request.max_tokens,
+                top_k=request.top_k
+            ),
+            media_type="text/event-stream"
+        )
+    else:
+        # Non-streaming response
+        temperature = request.temperature if request.temperature is not None else args.temperature
+        max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
+        top_k = request.top_k if request.top_k is not None else args.top_k
+
+        with autocast_ctx:
+            result_tokens, masks = engine.generate_batch(
+                conversation_tokens,
+                num_samples=1,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_k=top_k
+            )
+        result_tokens = result_tokens[0] # num_samples=1 => take the single sampled sequence
+
+        response_tokens = 
result_tokens[len(conversation_tokens):] + response_text = tokenizer.decode(response_tokens) + return { + "choices": [{ + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + }] + } + +@app.get("/health") +async def health(): + """Health check endpoint.""" + return { + "status": "ok", + "ready": hasattr(app.state, 'model') and app.state.model is not None + } + +if __name__ == "__main__": + import uvicorn + print(f"Starting NanoChat Web Server") + print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") + uvicorn.run(app, host=args.host, port=args.port) diff --git a/scripts/mid_train.py b/scripts/mid_train.py new file mode 100644 index 0000000..202682d --- /dev/null +++ b/scripts/mid_train.py @@ -0,0 +1,289 @@ +""" +Midtrain the model. Same as pretraining but simpler. +Run as: + +python -m scripts.mid_train + +Or torchrun for training: + +torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 +""" + +from collections import deque +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import time +import wandb +import torch + +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir +from nanochat.tokenizer import get_token_bytes +from nanochat.checkpoint_manager import save_checkpoint +from nanochat.loss_eval import evaluate_bpb +from nanochat.checkpoint_manager import load_model +import torch.distributed as dist + +from tasks.common import TaskMixture +from tasks.gsm8k import GSM8K +from tasks.mmlu import MMLU +from tasks.smoltalk import SmolTalk + +# ----------------------------------------------------------------------------- +run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +model_tag = None # model tag to load the model from (base model or midtrained model) +step = None # step to load the model from (base model or midtrained model) +dtype = "bfloat16" +max_seq_len = 2048 +device_batch_size = 32 +unembedding_lr = 0.004 +embedding_lr = 0.2 +matrix_lr = 0.02 +init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate +weight_decay = 0.0 +final_lr_frac = 0.0 # final LR is this fraction of the initial LR +eval_every = 150 +eval_tokens = 20*524288 +total_batch_size = 524288 +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging +# ----------------------------------------------------------------------------- + +# Compute init +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +master_process = ddp_rank == 0 +dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 +autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) + +# wandb logging init +use_dummy_wandb = run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config) + +# Load the model and tokenizer +model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) +pretrain_batch_size = meta.get("device_batch_size", None) +if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: + print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, 
did you pass in a good --device_batch_size to this script?") +orig_model = model +model = torch.compile(model, dynamic=False) +depth = model.config.n_layer +num_flops_per_token = model.estimate_flops() +tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +token_bytes = get_token_bytes(device=device) + +# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) +optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +adamw_optimizer, muon_optimizer = optimizers +# Override the initial learning rate as a fraction of the base learning rate +for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["lr"] * init_lr_frac + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + +# Midtraining data mixture and DataLoader +base_dir = get_base_dir() +train_dataset = TaskMixture([ + SmolTalk(split="train"), # 460K rows of general conversations + MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE + GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use +]) # total: 460K + 100K + 8K = 568K rows +val_dataset = TaskMixture([ + SmolTalk(split="test"), # 24K rows in test set + MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios + GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios +]) # total: 24K + 14K + 1.32K ~= 39K rows +# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) +# A big problem is that we don't know the final num_iterations in advance. So we create +# these two global variables and update them from within the data generator. 
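+# The train generator flips last_step once it wraps around the dataset (i.e. after one epoch),
+# and the LR schedule below keys off approx_progress instead of a precomputed iteration count.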
+last_step = False # we will toggle this to True when we reach the end of the dataset
+approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
+def mid_data_generator(split):
+    global last_step, approx_progress
+    assert split in {"train", "val"}, "split must be 'train' or 'val'"
+    dataset = train_dataset if split == "train" else val_dataset
+    dataset_size = len(dataset)
+    assert dataset_size > 0
+    needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
+    token_buffer = deque()
+    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
+    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
+    while True:
+        # Accumulate enough tokens for one iteration before yielding
+        while len(token_buffer) < needed_tokens:
+            conversation = dataset[cursor]
+            ids, _ = tokenizer.render_conversation(conversation)
+            token_buffer.extend(ids)
+            cursor += ddp_world_size
+            if cursor >= dataset_size:
+                cursor -= dataset_size # wrap around for another epoch
+                if split == "train":
+                    last_step = True # toggle last_step to True, which will terminate the training loop
+        # Build up inputs/targets and yield
+        for i in range(needed_tokens):
+            scratch[i] = token_buffer.popleft()
+        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+        targets_cpu = scratch[1:]
+        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
+        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
+        if split == "train":
+            approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
+        yield inputs, targets
+
+train_loader = mid_data_generator("train")
+build_val_loader = lambda: mid_data_generator("val")
+progress = 0 # will go from 0 to 1 over the course of the epoch
+
+# Learning rate scheduler: decay linearly from 1.0 down to final_lr_frac over the epoch
+def get_lr_multiplier(progress):
+    return (1 - progress) * 1.0 + progress * final_lr_frac
+
+# Momentum scheduler for Muon optimizer
+def get_muon_momentum(it):
+    frac = min(it / 300, 1)
+    momentum = (1 - frac) * 0.85 + frac * 0.95
+    return momentum
+
+# -----------------------------------------------------------------------------
+# Training loop
+x, y = next(train_loader) # prefetch the very first batch of data
+min_val_bpb = float("inf")
+smooth_train_loss = 0 # EMA of training loss
+ema_beta = 0.9 # EMA decay factor
+total_training_time = 0 # total wall-clock time of training
+step = 0
+while True:
+    flops_so_far = num_flops_per_token * total_batch_size * step
+
+    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
+    if ddp:
+        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
+        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
+        last_step = bool(last_step_tensor.item())
+
+    # once in a while: evaluate the val bpb (all ranks participate)
+    if last_step or step % eval_every == 0:
+        model.eval()
+        val_loader = build_val_loader()
+        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
+        with autocast_ctx:
+            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
+        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
+        if val_bpb < min_val_bpb:
+            min_val_bpb = val_bpb
+        wandb_run.log({
+            "step": step,
+            "total_training_flops": flops_so_far,
+            "total_training_time": total_training_time,
+            "val/bpb": val_bpb,
+        })
+        model.train()
+
+    # save checkpoint at the end of the run (only on master 
process) + if master_process and last_step: + output_dirname = f"d{depth}" # e.g. d12 + checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), + [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly + { + "step": step, + "val_bpb": val_bpb, # loss at last step + "model_config": { + "sequence_len": max_seq_len, + "vocab_size": tokenizer.get_vocab_size(), + "n_layer": depth, + "n_head": model.config.n_head, + "n_kv_head": model.config.n_kv_head, + "n_embd": model.config.n_embd, + }, + "user_config": user_config, # inputs to the training script + } + ) + + if last_step: + break + + # ------------------------------------------------------------------------- + # single training step + # evaluate the gradient + torch.cuda.synchronize() + t0 = time.time() + for micro_step in range(grad_accum_steps): + with autocast_ctx: + loss = model(x, y) + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + loss.backward() + x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + progress = max(progress, approx_progress) # only increase progress monotonically + # step the optimizers + lrm = get_lr_multiplier(progress) + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lrm + muon_momentum = get_muon_momentum(step) + for group in muon_optimizer.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) + torch.cuda.synchronize() + t1 = time.time() + dt = t1 - t0 + # ------------------------------------------------------------------------- + + # State + step += 1 + + # logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * progress + tok_per_sec = int(world_tokens_per_fwdbwd / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt + promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity + mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + if step % 10 == 0: + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + }) + +# print a few more stats +print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB") +print0(f"Total training time: {total_training_time/60:.2f}m") +print0(f"Minimum validation bpb: {min_val_bpb:.4f}") + +# Log to report +from nanochat.report import get_report +get_report().log(section="Midtraining", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } +]) + +# cleanup +wandb_run.finish() # wandb run finish 
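+# compute_cleanup() lives in nanochat/common.py and is expected to tear down
+# the torch.distributed process group in DDP runs so that torchrun exits cleanly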
+compute_cleanup() diff --git a/scripts/tok_eval.py b/scripts/tok_eval.py new file mode 100644 index 0000000..9233d71 --- /dev/null +++ b/scripts/tok_eval.py @@ -0,0 +1,265 @@ +""" +Evaluate compression ratio of the tokenizer. +""" + +from nanochat.tokenizer import get_tokenizer, RustBPETokenizer +from nanochat.dataset import parquets_iter_batched + +# Random text I got from a random website this morning +news_text = r""" +(Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025. + +While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately. + +β€œThe United States has promised to be vigilant β€” and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. β€œThanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.” +""".strip() + +# Random Korean text (to test non-English compression) +korean_text = r""" +μ •μ§ν•œ 사싀 μœ„μ—, κ³΅μ •ν•œ μ‹œμ„ μ„ λ”ν•˜λ‹€ +Herald Korea Times + +ν—€λŸ΄λ“œμ½”λ¦¬μ•„νƒ€μž„μ¦ˆλŠ” μ •μΉ˜, 경제, μ‚¬νšŒ, λ¬Έν™” λ“± ν•œκ΅­ μ‚¬νšŒ μ „λ°˜μ˜ μ£Όμš” 이슈λ₯Ό 심도 있게 λ‹€λ£¨λŠ” μ’…ν•© 온라인 μ‹ λ¬Έμ‚¬μž…λ‹ˆλ‹€. + +μš°λ¦¬λŠ” λ‹¨μˆœνžˆ λ‰΄μŠ€λ₯Ό μ „λ‹¬ν•˜λŠ” 것이 μ•„λ‹ˆλΌ, 사싀(Fact)에 κΈ°λ°˜ν•œ μ–‘μΈ‘μ˜ μ‹œκ°μ„ κ· ν˜• 있게 μ‘°λͺ…ν•˜λ©°, λ…μž μ—¬λŸ¬λΆ„μ΄ 슀슀둜 νŒλ‹¨ν•  수 μžˆλŠ” β€˜μ •λ³΄μ˜ κ· ν˜•β€™μ„ μ œκ³΅ν•©λ‹ˆλ‹€. + +ν•œκ΅­ μ–Έλ‘ μ˜ 였랜 문제둜 μ§€μ λ˜μ–΄ 온 μ •μΉ˜μ  편ν–₯, 이념적 μ™œκ³‘μ—μ„œ λ²—μ–΄λ‚˜ +였직 정직함과 곡정함을 μ›μΉ™μœΌλ‘œ μ‚ΌλŠ” 언둠을 μ§€ν–₯ν•©λ‹ˆλ‹€. +μ–΄λŠ ν•œμͺ½μ˜ μ£Όμž₯λ§Œμ„ ν™•λŒ€ν•˜κ±°λ‚˜ 감좔지 μ•Šκ³ , +**λͺ¨λ“  μŸμ μ— λŒ€ν•΄ β€˜λ¬΄μ—‡μ΄ μŸμ μΈμ§€β€™, β€˜λˆ„κ°€ 무엇을 μ£Όμž₯ν•˜λŠ”μ§€β€™, β€˜μ‚¬μ‹€μ€ 무엇인지’**λ₯Ό λͺ…ν™•νžˆ μ „λ‹¬ν•˜λŠ” 데 μ§‘μ€‘ν•©λ‹ˆλ‹€. 
+""".strip() + +# Random piece of code +code_text = r""" +class BasicTokenizer(Tokenizer): + + def __init__(self): + super().__init__() + + def train(self, text, vocab_size, verbose=False): + assert vocab_size >= 256 + num_merges = vocab_size - 256 + + # input text preprocessing + text_bytes = text.encode("utf-8") # raw bytes + ids = list(text_bytes) # list of integers in range 0..255 + + # iteratively merge the most common pairs to create new tokens + merges = {} # (int, int) -> int + vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes + for i in range(num_merges): + # count up the number of times every consecutive pair appears + stats = get_stats(ids) + # find the pair with the highest count + pair = max(stats, key=stats.get) + # mint a new token: assign it the next available id + idx = 256 + i + # replace all occurrences of pair in ids with idx + ids = merge(ids, pair, idx) + # save the merge + merges[pair] = idx + vocab[idx] = vocab[pair[0]] + vocab[pair[1]] + # prints + if verbose: + print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") +""".strip() + +math_text = r""" +\documentclass[12pt]{article} +\usepackage{amsmath,amsthm,amssymb} +\usepackage[margin=1in]{geometry} + +\newtheorem{theorem}{Theorem} +\newtheorem*{remark}{Remark} + +\begin{document} + +\begin{center} +{\Large A Cute Identity: The Sum of Cubes is a Square} +\end{center} + +\begin{theorem} +For every integer $n \ge 1$, +\[ +\sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}. +\] +\end{theorem} + +\begin{proof}[Proof 1 (Induction)] +Let $S(n) = \sum_{k=1}^{n} k^3$. For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds. + +Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$. +Then +\[ +S(n+1) += S(n) + (n+1)^3 += \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3. +\] +Factor out $(n+1)^2$: +\[ +S(n+1) += (n+1)^2\left( \frac{n^2}{4} + (n+1) \right) += (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right) += (n+1)^2\left( \frac{(n+2)^2}{4} \right). +\] +Thus +\[ +S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2, +\] +which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$. +\end{proof} + +\begin{proof}[Proof 2 (Algebraic telescoping)] +Recall the binomial identity +\[ +(k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1. +\] +Summing both sides from $k=0$ to $n$ telescopes: +\[ +(n+1)^4 - 0^4 += \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big) += 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1). +\] +Using the standard sums +\[ +\sum_{k=1}^{n}k = \frac{n(n+1)}{2} +\quad\text{and}\quad +\sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6}, +\] +solve for $\sum_{k=1}^{n}k^3$ to get +\[ +\sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2. +\] +\end{proof} + +\begin{remark} +Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon. 
+\end{remark} + +\end{document} +""".strip() + +science_text = r""" +Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of COβ‚‚ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity. +""".strip() + +# The tokenizer was trained on data from earlier shards, so it has seen this data +train_docs = next(parquets_iter_batched(split="train")) +train_text = "\n".join(train_docs) +val_docs = next(parquets_iter_batched(split="val")) +val_text = "\n".join(val_docs) + +all_text = [ + ("news", news_text), + ("korean", korean_text), + ("code", code_text), + ("math", math_text), + ("science", science_text), + ("fwe-train", train_text), +] +if val_text: + all_text.append(("fwe-val", val_text)) + +# Try out current default compared to GPT-2 and GPT-4 tokenizers +tokenizer_results = {} +vocab_sizes = {} + +for tokenizer_name in ["gpt2", "gpt4", "ours"]: + + if tokenizer_name == "gpt2": + tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer + elif tokenizer_name == "gpt4": + tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer + else: + tokenizer = get_tokenizer() + + vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size() + tokenizer_results[tokenizer_name] = {} + + for name, text in all_text: + encoded = tokenizer.encode(text) + decoded = tokenizer.decode(encoded) + assert decoded == text + + encoded_bytes = text.encode('utf-8') + ratio = len(encoded_bytes) / len(encoded) + tokenizer_results[tokenizer_name][name] = { + 'bytes': len(encoded_bytes), + 'tokens': len(encoded), + 'ratio': ratio + } + +# ANSI color codes +GREEN = '\033[92m' +RED = '\033[91m' +RESET = '\033[0m' + +# Print vocab sizes +print(f"\nVocab sizes:") +print(f"GPT-2: {vocab_sizes['gpt2']}") +print(f"GPT-4: {vocab_sizes['gpt4']}") +print(f"Ours: {vocab_sizes['ours']}") + +def print_comparison(baseline_name, baseline_results, ours_results, all_text): + """Print comparison table between baseline tokenizer and ours.""" + print(f"\nComparison with {baseline_name}:") + print("=" * 95) + print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}") + print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}") + print("-" * 95) + + for name, text in all_text: + baseline_data = baseline_results[name] + ours_data = ours_results[name] + + # Calculate relative difference (positive means ours is better, negative means worse) + # Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens + 
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
+
+        # Determine which has better compression (higher ratio = better)
+        if baseline_data['ratio'] > ours_data['ratio']:
+            baseline_color, ours_color = GREEN, RED
+            better = baseline_name
+            diff_color = RED
+        elif ours_data['ratio'] > baseline_data['ratio']:
+            baseline_color, ours_color = RED, GREEN
+            better = "Ours"
+            diff_color = GREEN
+        else:
+            baseline_color, ours_color = "", ""
+            better = "Tie"
+            diff_color = ""
+
+        print(f"{name:<10} {baseline_data['bytes']:<8} "
+              f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
+              f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
+              f"{ours_color}{ours_data['tokens']:<7}{RESET} "
+              f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
+              f"{diff_color}{relative_diff:+7.1f}%{RESET} "
+              f"{better:<10}")
+
+# Print comparisons
+print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
+print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text)
+
+# Log to report
+from nanochat.report import get_report
+lines = []
+for baseline_name in ["GPT-2", "GPT-4"]:
+    baseline_key = baseline_name.lower().replace('-', '')
+    baseline_results = tokenizer_results[baseline_key]
+    ours_results = tokenizer_results['ours']
+    lines.append(f"### Comparison with {baseline_name}")
+    lines.append("")
+    lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |")
+    lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
+    for name, text in all_text:
+        baseline_data = baseline_results[name]
+        ours_data = ours_results[name]
+        relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
+        lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |")
+    lines.append("")
+report_markdown = "\n".join(lines)
+get_report().log(section="Tokenizer evaluation", data=[
+    report_markdown,
+])
diff --git a/scripts/tok_train.py b/scripts/tok_train.py
new file mode 100644
index 0000000..c2faf17
--- /dev/null
+++ b/scripts/tok_train.py
@@ -0,0 +1,106 @@
+"""
+Train a BPE tokenizer in the style of the GPT-4 tokenizer, using our own
+rustbpe implementation (tiktoken is used downstream for efficient inference).
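+
+Example invocation (this is what speedrun.sh runs, training on ~2B characters):
+
+    python -m scripts.tok_train --max_chars=2000000000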
+""" +import os +import time +import argparse +import torch +from nanochat.tokenizer import RustBPETokenizer +from nanochat.common import get_base_dir +from nanochat.dataset import parquets_iter_batched + +# ----------------------------------------------------------------------------- +# Parse command line arguments + +parser = argparse.ArgumentParser(description='Train a BPE tokenizer') +parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') +parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') +parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') +args = parser.parse_args() +print(f"max_chars: {args.max_chars:,}") +print(f"doc_cap: {args.doc_cap:,}") +print(f"vocab_size: {args.vocab_size:,}") + +# ----------------------------------------------------------------------------- +# Text iterator + +def text_iterator(): + """ + 1) Flatten the batches into a single iterator + 2) Crop every document to args.doc_cap characters + 3) Break when we've seen args.max_chars characters + """ + nchars = 0 + for batch in parquets_iter_batched(split="train"): + for doc in batch: + doc_text = doc + if len(doc_text) > args.doc_cap: + doc_text = doc_text[:args.doc_cap] + nchars += len(doc_text) + yield doc_text + if nchars > args.max_chars: + return +text_iter = text_iterator() + +# ----------------------------------------------------------------------------- +# Train the tokenizer +t0 = time.time() +tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size) +t1 = time.time() +train_time = t1 - t0 +print(f"Training time: {train_time:.2f}s") + +# ----------------------------------------------------------------------------- +# Save the tokenizer to disk +base_dir = get_base_dir() +tokenizer_dir = os.path.join(base_dir, "tokenizer") +tokenizer.save(tokenizer_dir) + +# ----------------------------------------------------------------------------- +# Quick inline sanity check +test_text = """Hello world! This is a test. +Numbers: 123, 4567, 89 +Contractions: I'm, you're, it's +Special chars: @#$%^&*() +Unicode: δ½ ε₯½δΈ–η•Œ 🌍""" +encoded = tokenizer.encode(test_text) +decoded = tokenizer.decode(encoded) +assert decoded == test_text + +# ----------------------------------------------------------------------------- +# One more thing: we wish to cache a mapping from token id to number of bytes of that token +# for efficient evaluation of bits per byte. Unlike the typical mean loss, this +# allows us to report a loss that is invariant to the vocab size of the tokenizer. +# The bits per byte on the validation set is then one of the primary metrics we care about. 
+vocab_size = tokenizer.get_vocab_size()
+special_set = set(tokenizer.get_special_tokens())
+token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
+token_bytes = []
+for token_id in range(vocab_size):
+    token_str = token_strings[token_id] # the Python string representation of this token
+    if token_str in special_set:
+        token_bytes.append(0) # special tokens are not counted
+    else:
+        id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
+        token_bytes.append(id_bytes)
+token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
+token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
+with open(token_bytes_path, "wb") as f:
+    torch.save(token_bytes, f)
+print(f"Saved token_bytes to {token_bytes_path}")
+
+# Log to report
+from nanochat.report import get_report
+token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
+get_report().log(section="Tokenizer training", data=[
+    vars(args), # argparse command line arguments
+    {"train_time": train_time},
+    {"num_special_tokens": len(special_set)},
+    {
+        "token_bytes_min": int(token_bytes_nonzero.min().item()),
+        "token_bytes_max": int(token_bytes_nonzero.max().item()),
+        "token_bytes_mean": token_bytes_nonzero.mean().item(),
+        "token_bytes_std": token_bytes_nonzero.std().item(),
+    }
+])
diff --git a/speedrun.sh b/speedrun.sh
new file mode 100644
index 0000000..d2498ee
--- /dev/null
+++ b/speedrun.sh
@@ -0,0 +1,133 @@
+#!/bin/bash
+
+# This script is the "Best ChatGPT clone that $100 can buy".
+# It is designed to run in ~4 hours on an 8XH100 node at $3/GPU/hour.
+
+# 1) Example launch (simplest):
+# bash speedrun.sh
+# 2) Example launch in a screen session (because the run takes ~4 hours):
+# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
+# 3) Example launch with wandb logging, but see below for setting up wandb first:
+# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
+
+# Default intermediate artifacts directory is in ~/.cache/nanochat
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+
+# -----------------------------------------------------------------------------
+# Python venv setup with uv
+
+# install uv (if not already installed)
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+# create a .venv local virtual environment (if it doesn't exist)
+[ -d ".venv" ] || uv venv
+# install the repo dependencies
+uv sync
+# activate venv so that `python` uses the project's venv instead of system python
+source .venv/bin/activate
+
+# -----------------------------------------------------------------------------
+# wandb setup
+# If you wish to use wandb for logging (it's nice, and recommended):
+# 1) Make sure to first log in to wandb, e.g. run:
+# `wandb login`
+# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
+# `WANDB_RUN=d26 bash speedrun.sh`
+if [ -z "$WANDB_RUN" ]; then
+    # by default use "dummy" : it's handled as a special case, skips logging to wandb
+    WANDB_RUN=dummy
+fi
+
+# -----------------------------------------------------------------------------
+# During the course of the run, we will be writing markdown reports to the report/
+# directory in the base dir. This command clears it out and writes a header section
+# with a bunch of system info and a timestamp that marks the start of the run.
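+# (Each subsequent script appends its own section to the report via
+# nanochat.report's get_report().log(...); the final `python -m nanochat.report
+# generate` at the bottom of this file stitches them into report.md.)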
+python -m nanochat.report reset + +# ----------------------------------------------------------------------------- +# Tokenizer + +# Install Rust / Cargo +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" + +# Build the rustbpe Tokenizer +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# Download the first ~2B characters of pretraining dataset +# look at dev/repackage_data_reference.py for details on how this data was prepared +# each data shard is ~250M chars +# so we download 2e9 / 250e6 = 8 data shards at this point +# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk +python -m nanochat.dataset -n 8 +# Immediately also kick off downloading more shards in the background while tokenizer trains +# See comment below for why 240 is the right number here +python -m nanochat.dataset -n 240 & +DATASET_DOWNLOAD_PID=$! +# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data +python -m scripts.tok_train --max_chars=2000000000 +# evaluate the tokenizer (report compression ratio etc.) +python -m scripts.tok_eval + +# ----------------------------------------------------------------------------- +# Base model (pretraining) + +# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + +# The d20 model is 561M parameters. +# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. +# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. +# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining. +# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk. +# (The total number of shards available in the entire dataset is 1822.) +echo "Waiting for dataset download to complete..." +wait $DATASET_DOWNLOAD_PID + +# pretrain the d20 model +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +# evaluate the model on a larger chunk of train/val data and draw some samples +torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +# evaluate the model on CORE tasks +torchrun --standalone --nproc_per_node=8 -m scripts.base_eval + +# ----------------------------------------------------------------------------- +# Midtraining (teach the model conversation special tokens, tool use, multiple choice) + +# run midtraining and eval the model +torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid + +# ----------------------------------------------------------------------------- +# Supervised Finetuning (domain adaptation to each sequence all by itself per row) + +# train sft and re-eval right away (should see a small bump) +torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft + +# chat with the model over CLI! Leave out the -p to chat interactively +# python -m scripts.chat_cli -p "Why is the sky blue?" 
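+# (note: the -i flag passed to scripts.chat_eval above selects which checkpoint
+# stage to evaluate; later, after RL, it can also be pointed at -i rl)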
+
+# even better, chat with your model over a pretty WebUI ChatGPT style
+# python -m scripts.chat_web
+
+# -----------------------------------------------------------------------------
+# Reinforcement Learning (optional). Currently only on GSM8K.
+
+# run reinforcement learning
+# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
+# eval the RL model only on GSM8K
+# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
+
+# -----------------------------------------------------------------------------
+# Generate the full report by putting together all the sections
+# report.md is the output and will be copied to current directory for convenience
+python -m nanochat.report generate
diff --git a/tasks/arc.py b/tasks/arc.py
new file mode 100644
index 0000000..862cca9
--- /dev/null
+++ b/tasks/arc.py
@@ -0,0 +1,48 @@
+"""
+The ARC dataset from Allen AI.
+https://huggingface.co/datasets/allenai/ai2_arc
+"""
+
+from datasets import load_dataset
+from tasks.common import Task, render_mc
+
+class ARC(Task):
+
+    def __init__(self, subset, split, **kwargs):
+        super().__init__(**kwargs)
+        assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
+        assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
+        self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)
+
+    @property
+    def eval_type(self):
+        return 'categorical'
+
+    def num_examples(self):
+        return len(self.ds)
+
+    def get_example(self, index):
+        row = self.ds[index]
+        question = row["question"] # the question text
+        choices = row["choices"]["text"] # the text of each choice
+        answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
+        letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
+        assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
+        # create and return the Conversation object
+        user_message = render_mc(question, letters, choices)
+        messages = [
+            {"role": "user", "content": user_message},
+            {"role": "assistant", "content": answer_string}
+        ]
+        conversation = {
+            "messages": messages,
+            "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
+        }
+        return conversation
+
+    def evaluate(self, conversation, assistant_response):
+        # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
+        # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
+        assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
+        assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
+        return assistant_response == assistant_message
diff --git a/tasks/common.py b/tasks/common.py
new file mode 100644
index 0000000..dcd2e91
--- /dev/null
+++ b/tasks/common.py
@@ -0,0 +1,147 @@
+"""
+Base class for all Tasks.
+A Task is basically a dataset of conversations, together with some
+metadata and often also evaluation criteria.
+Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
+"""
+
+import random
+
+class Task:
+    """
+    Base class of a Task. Allows for lightweight slicing of the underlying dataset.
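+
+    For example, MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
+    is a logical view of examples 5..9 of that dataset (like dataset[5:10]),
+    without copying any data; see the __main__ block at the bottom of this file.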
+    """
+
+    def __init__(self, start=0, stop=None, step=1):
+        # allows a lightweight logical view over a dataset
+        assert start >= 0, f"Start must be non-negative, got {start}"
+        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
+        assert step >= 1, f"Step must be strictly positive, got {step}"
+        self.start = start
+        self.stop = stop # could be None here
+        self.step = step
+
+    @property
+    def eval_type(self):
+        # one of 'generative' | 'categorical'
+        raise NotImplementedError
+
+    def num_examples(self):
+        raise NotImplementedError
+
+    def get_example(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        start = self.start
+        stop = self.num_examples() if self.stop is None else self.stop
+        step = self.step
+        span = stop - start
+        num = (span + step - 1) // step # ceil_div(span, step)
+        assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
+        return num
+
+    def __getitem__(self, index: int):
+        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
+        physical_index = self.start + index * self.step
+        conversation = self.get_example(physical_index)
+        return conversation
+
+    def evaluate(self, problem, completion):
+        raise NotImplementedError
+
+
+class TaskMixture(Task):
+    """
+    For SFT Training it becomes useful to train on a task mixture of datasets.
+    Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
+    """
+
+    def __init__(self, tasks, **kwargs):
+        super().__init__(**kwargs)
+        # tasks is a list of Task objects
+        self.tasks = tasks
+        self.lengths = [len(task) for task in self.tasks]
+        self.num_conversations = sum(self.lengths)
+        # Build list of all (task_idx, local_idx) pairs
+        self.index_map = []
+        for task_idx, task_length in enumerate(self.lengths):
+            for local_idx in range(task_length):
+                self.index_map.append((task_idx, local_idx))
+        # Deterministically shuffle to mix tasks throughout training
+        rng = random.Random(42)
+        rng.shuffle(self.index_map)
+        # Note: this is not the most elegant or best solution, but it's ok for now
+
+    def num_examples(self):
+        return self.num_conversations
+
+    def get_example(self, index):
+        """
+        Access conversations according to a deterministic shuffle of all examples.
+        This ensures tasks are mixed throughout training, regardless of dataset size.
+        """
+        assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
+        task_idx, local_idx = self.index_map[index]
+        return self.tasks[task_idx][local_idx]
+
+
+class TaskSequence(Task):
+    """
+    For SFT Training sometimes we want to sequentially train on a list of tasks.
+    This is useful for cases that require a training curriculum.
+    """
+
+    def __init__(self, tasks, **kwargs):
+        super().__init__(**kwargs)
+        self.tasks = tasks
+        self.lengths = [len(task) for task in self.tasks]
+        self.num_conversations = sum(self.lengths)
+
+    def num_examples(self):
+        return self.num_conversations
+
+    def get_example(self, index):
+        assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
+        for task_idx, task_length in enumerate(self.lengths):
+            if index < task_length:
+                return self.tasks[task_idx][index]
+            index -= task_length
+
+
+def render_mc(question, letters, choices):
+    """
+    The common multiple choice rendering format we will use.
+ + Note two important design decisions: + 1) + Bigger models don't care as much, but smaller models prefer to have + the letter *after* the choice, which results in better binding. + 2) + There is no whitespace between the delimiter (=) and the letter. + This is actually critical because the tokenizer has different token ids + for " A" vs. "A". The assistant responses will be just the letter itself, + i.e. "A", so it is important that here in the prompt it is the exact same + token, i.e. "A" with no whitespace before it. Again, bigger models don't care + about this too much, but smaller models do care about some of these details. + """ + query = f"Multiple Choice question: {question}\n" + query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]) + query += "\nRespond only with the letter of the correct answer." + return query + + +if __name__ == "__main__": + # very lightweight test of slicing + from tasks.mmlu import MMLU + + ds = MMLU(subset="auxiliary_train", split="train") + print("Length of MMLU: ", len(ds)) + ex = ds[5] + print("5th example: ", ex) + + ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10) + print("Length of sliced MMLU[5:10]: ", len(ds)) + print("0th example of sliced MMLU: ", ds[0]) + + print("They match: ", ex == ds[0]) diff --git a/tasks/gsm8k.py b/tasks/gsm8k.py new file mode 100644 index 0000000..c05e21c --- /dev/null +++ b/tasks/gsm8k.py @@ -0,0 +1,117 @@ +""" +GSM8K evaluation. +https://huggingface.co/datasets/openai/gsm8k + +Example problem instance: + +Question: +Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? +Answer: +Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. +Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. +#### 10 + +Notice that GSM8K uses tool calls inside << >> tags. +""" + +import re +from datasets import load_dataset +from tasks.common import Task + + +GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)") +def extract_answer(completion): + """ + Extract the numerical answer after #### marker. + Follows official code for normalization: + https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28 + """ + match = GSM_RE.search(completion) + if match: + match_str = match.group(1).strip() + match_str = match_str.replace(",", "") + return match_str + return None + + +class GSM8K(Task): + + def __init__(self, subset, split, **kwargs): + super().__init__(**kwargs) + assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic" + assert split in ["train", "test"], "GSM8K split must be train|test" + self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42) + + @property + def eval_type(self): + return 'generative' + + def num_examples(self): + return len(self.ds) + + def get_example(self, index): + """ Get a single problem from the dataset. """ + row = self.ds[index] + question = row['question'] # string of the question prompt + answer = row['answer'] # string of the full solution and the answer after #### marker + # Create and return the Conversation object + # This is tricky because GSM8K uses tool calls, which we need to parse here. 
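+        # For example, "Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute." becomes:
+        #   {"type": "text", "text": "Weng earns 12/60 = $"},
+        #   {"type": "python", "text": "12/60"},
+        #   {"type": "python_output", "text": "0.2"},
+        #   {"type": "text", "text": "0.2 per minute."}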
+        assistant_message_parts = []
+        parts = re.split(r'(<<[^>]+>>)', answer)
+        for part in parts:
+            if part.startswith('<<') and part.endswith('>>'):
+                # This is a calculator tool call
+                inner = part[2:-2] # Remove << >>
+                # Split on = to get expression and result
+                if '=' in inner:
+                    expr, result = inner.rsplit('=', 1)
+                else:
+                    expr, result = inner, ""
+                # Add the tool call as a part
+                assistant_message_parts.append({"type": "python", "text": expr})
+                # Add the result as a part
+                assistant_message_parts.append({"type": "python_output", "text": result})
+            else:
+                # Regular text in between tool calls
+                assistant_message_parts.append({"type": "text", "text": part})
+        # Now put it all together
+        messages = [
+            {"role": "user", "content": question}, # note: simple string
+            {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
+        ]
+        conversation = {
+            "messages": messages,
+        }
+        return conversation
+
+    def evaluate(self, conversation, assistant_response):
+        """
+        Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
+        Note that:
+        - the conversation has both user AND assistant message (containing the ground truth answer)
+        - the assistant_response is usually the alternative assistant message achieved via sampling
+
+        TODO: Technically, assistant_response should be a Message (either a string or a list of parts)
+        We can handle this later possibly. For now just assume string.
+        """
+        assert isinstance(assistant_response, str), "Assuming simple string response for now"
+        # First extract the ground truth answer
+        assistant_message = conversation['messages'][-1]
+        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
+        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
+        last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
+        # Extract both the ground truth answer and the predicted answer
+        ref_num = extract_answer(last_text_part)
+        pred_num = extract_answer(assistant_response)
+        # Compare and return the success as int
+        is_correct = int(pred_num == ref_num)
+        return is_correct
+
+    def reward(self, conversation, assistant_response):
+        """
+        Used during RL. To keep things simple, just re-use the evaluation above.
+        Later this could be made more complex (e.g. format matching etc.)
+        """
+        is_correct = self.evaluate(conversation, assistant_response)
+        is_correct_float = float(is_correct)
+        return is_correct_float
diff --git a/tasks/humaneval.py b/tasks/humaneval.py
new file mode 100644
index 0000000..e9dd489
--- /dev/null
+++ b/tasks/humaneval.py
@@ -0,0 +1,97 @@
+"""
+Evaluate the Chat model on HumanEval dataset.
+Btw this dataset is a misnomer and has nothing to do with humans.
+It is a coding benchmark.
+"""
+
+import re
+from datasets import load_dataset
+from nanochat.execution import execute_code
+from tasks.common import Task
+
+def extract_imports(prompt):
+    """Extract import statements from the beginning of a code block."""
+    imports = []
+    for line in prompt.split('\n'):
+        stripped = line.strip()
+        if stripped.startswith('import ') or stripped.startswith('from '):
+            imports.append(stripped)
+        elif stripped and not stripped.startswith('#'):
+            # Stop at first non-import, non-comment line
+            break
+    return '\n'.join(imports)
+
+def extract_program(completion):
+    """
+    Extract Python code from LLM completion.
+
+    Handles various output formats:
+    - Code wrapped in ```python ... ``` or ``` ...
``` blocks + - Plain code without markdown blocks + - Extra text before/after code blocks + + Returns the first code block if found, otherwise returns the whole completion. + """ + # Try to find markdown code blocks (```python or just ```) + # Match ```python\n...\n``` or ```\n...\n``` + pattern = r'```(?:python)?\s*\n(.*?)\n```' + matches = re.findall(pattern, completion, re.DOTALL) + + if matches: + # Return the first code block found + return matches[0].strip() + + # No code blocks found, return the whole completion + return completion.strip() + +class HumanEval(Task): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) + + @property + def eval_type(self): + return 'generative' + + def num_examples(self): + return len(self.ds) + + def get_example(self, index): + """ Get a single problem from the dataset. """ + row = self.ds[index] + prompt = row['prompt'] # prompts in HumanEval are the beginning of the program + solution = row['canonical_solution'] # the correct continuation of the program + entry_point = row['entry_point'] # the function to check + test = row['test'] # the test cases + complete_solution = f"{prompt}\n{solution}" + messages = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": complete_solution}, + ] + conversation = { + "messages": messages, + "entry_point": entry_point, # needed during evaluation + "test": test, # needed during evaluation + } + return conversation + + def evaluate(self, conversation, completion): + """ Given (conversation, completion), return boolean success of the completion. """ + # the prompt will contain the imports and the function signature + imports = extract_imports(conversation['messages'][0]['content']) + # the completion will usually contain the whole function + # but not always with the needed imports, so we manually append them + completion_code = extract_program(completion) + program = ( + imports + + "\n\n" + + completion_code + + "\n\n" + + conversation['test'] + + "\n" + + f"check({conversation['entry_point']})" + ) + result = execute_code(program) + success = result.success + return success diff --git a/tasks/mmlu.py b/tasks/mmlu.py new file mode 100644 index 0000000..3ba2254 --- /dev/null +++ b/tasks/mmlu.py @@ -0,0 +1,60 @@ +""" +The MMLU dataset. 
+https://huggingface.co/datasets/cais/mmlu +""" + +from datasets import load_dataset +from tasks.common import Task, render_mc + +class MMLU(Task): + + letters = ('A', 'B', 'C', 'D') + groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') + + def __init__(self, subset, split, **kwargs): + super().__init__(**kwargs) + assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" + assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" + if subset == "auxiliary_train": + assert split == "train", "auxiliary_train must be split into train" + self.subset = subset + self.split = split + self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) + if subset == "auxiliary_train": + # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper + self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) + + @property + def eval_type(self): + return 'categorical' + + def num_examples(self): + return len(self.ds) + + def get_example(self, index): + row = self.ds[index] + question = row["question"] # the question text + choices = row["choices"] # the text of each choice + answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) + subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc. + assert len(choices) == 4, "MMLU should have 4 choices" + # create and return the Conversation object + user_message = render_mc(question, self.letters, choices) + assistant_message = self.letters[answer] + messages = [ + {"role": "user", "content": user_message}, + {"role": "assistant", "content": assistant_message} + ] + conversation = { + "messages": messages, + "subject": subject, # might be useful later for grouping metrics by subject + "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters + } + return conversation + + def evaluate(self, conversation, assistant_response): + # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true + # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. 
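+        # (the eval harness narrows and clamps the sampled prediction to one of
+        # the answer letters -- see the 'letters' field returned by get_example
+        # above -- so this assert should hold by construction)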
+ assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" + assistant_message = conversation['messages'][-1]['content'] # e.g. "A" + return assistant_response == assistant_message diff --git a/tasks/smoltalk.py b/tasks/smoltalk.py new file mode 100644 index 0000000..b4d4f5f --- /dev/null +++ b/tasks/smoltalk.py @@ -0,0 +1,46 @@ +""" +SmolTalk by HuggingFace. Good "general" conversational dataset. +https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk +We use the "smol" version, which is more appropriate for smaller models. +""" + +from datasets import load_dataset +from tasks.common import Task + +class SmolTalk(Task): + """ smol-smoltalk dataset. train is 460K rows, test is 24K rows. """ + + def __init__(self, split, **kwargs): + super().__init__(**kwargs) + assert split in ["train", "test"], "SmolTalk split must be train|test" + self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42) + self.length = len(self.ds) + + def num_examples(self): + return self.length + + def get_example(self, index): + row = self.ds[index] + messages = row["messages"] + # --------------------------------------------------------------------- + # sanity checking asserts here + # TODO: we could remove these asserts later, for now just don't want any footguns + # there is an optional system message at the beginning + assert len(messages) >= 1 + first_message = messages[0] + if first_message["role"] == "system": + rest_messages = messages[1:] # optional system message is OK + else: + rest_messages = messages + assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages" + for i, message in enumerate(rest_messages): + # user and assistant alternate as user,assistant,user,assistant,... + expected_role = "user" if i % 2 == 0 else "assistant" + assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + assert isinstance(message["content"], str), "Content must be a string" + # --------------------------------------------------------------------- + # create and return the Conversation object (ok to emit the system message too) + conversation = { + "messages": messages, + } + return conversation diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py new file mode 100644 index 0000000..5f95721 --- /dev/null +++ b/tests/test_rustbpe.py @@ -0,0 +1,635 @@ +""" +Comparing the training of: + +1. (very slow) Python reference implementation +2. Optimized Python implementation +3. HuggingFace tokenizers training implementation +4. Our own custom RustBPE training implementation + +All of these should calculate the same merges and produce +the same vocabulary and tokenizations. + +Finally, for inference we will use tiktoken for efficiency. +So we want to make sure we can export our rustbpe tokenizer +into tiktoken and use it for inference with identical results. 
+ +Run with: +python -m pytest tests/test_rustbpe.py -v -s +-v is verbose, -s is show prints +""" + +import regex as re +from collections import Counter, defaultdict +import time +import rustbpe +import tiktoken +import pytest + +GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" + +# ----------------------------------------------------------------------------- +# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe + +def get_stats(ids, counts=None): + """ + Given a list of integers, return a dictionary of counts of consecutive pairs + Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} + Optionally allows to update an existing dictionary of counts + """ + counts = {} if counts is None else counts + for pair in zip(ids, ids[1:]): # iterate consecutive elements + counts[pair] = counts.get(pair, 0) + 1 + return counts + +def merge(ids, pair, idx): + """ + In the list of integers (ids), replace all consecutive occurrences + of pair with the new integer token idx + Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] + """ + newids = [] + i = 0 + while i < len(ids): + # if not at the very last position AND the pair matches, replace it + if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: + newids.append(idx) + i += 2 + else: + newids.append(ids[i]) + i += 1 + return newids + +class RegexTokenizer: + + def __init__(self, pattern=None): + """ + - pattern: optional string to override the default (GPT-4 split pattern) + - special_tokens: str -> int dictionary of special tokens + example: {'<|endoftext|>': 100257} + """ + self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern + self.merges = {} # (int, int) -> int + self.compiled_pattern = re.compile(self.pattern) + self.special_tokens = {} + self.inverse_special_tokens = {} + self.vocab = self._build_vocab() + + def _build_vocab(self): + # vocab is simply and deterministically derived from merges + vocab = {idx: bytes([idx]) for idx in range(256)} + for (p0, p1), idx in self.merges.items(): + vocab[idx] = vocab[p0] + vocab[p1] + for special, idx in self.special_tokens.items(): + vocab[idx] = special.encode("utf-8") + return vocab + + def train(self, text, vocab_size, verbose=False): + assert vocab_size >= 256 + num_merges = vocab_size - 256 + + # keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique) + ambiguous = False + + # split the text up into text chunks + text_chunks = re.findall(self.compiled_pattern, text) + + # input text preprocessing + ids = [list(ch.encode("utf-8")) for ch in text_chunks] + + # iteratively merge the most common pairs to create new tokens + merges = {} # (int, int) -> int + vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes + for i in range(num_merges): + # count the number of times every consecutive pair appears + stats = {} + for chunk_ids in ids: + # passing in stats will update it in place, adding up counts + get_stats(chunk_ids, stats) + # find the pair with the highest count + pair = max(stats, key=stats.get) + # check if the merge is ambiguous - i.e. the max value is not unique + pair_count = stats[pair] + pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count] + if len(pairs_with_max_count) > 1: + # print the top 10 pairs with their counts + # print(f"{i} Merge is ambiguous! 
{pair} has {pair_count} occurrences") + # for print_pair, print_count in sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]: + # print(f"{print_pair}: {print_count}") + ambiguous = True + # mint a new token: assign it the next available id + idx = 256 + i + # replace all occurrences of pair in ids with idx + ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] + # save the merge + merges[pair] = idx + vocab[idx] = vocab[pair[0]] + vocab[pair[1]] + # prints + if verbose: + print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") + + # save class variables + self.merges = merges # used in encode() + self.vocab = vocab # used in decode() + return ambiguous + + def _encode_chunk(self, text_bytes): + # return the token ids + # let's begin. first, convert all bytes to integers in range 0..255 + ids = list(text_bytes) + while len(ids) >= 2: + # find the pair with the lowest merge index + stats = get_stats(ids) + pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) + # subtle: if there are no more merges available, the key will + # result in an inf for every single pair, and the min will be + # just the first pair in the list, arbitrarily + # we can detect this terminating case by a membership check + if pair not in self.merges: + break # nothing else can be merged anymore + # otherwise let's merge the best pair (lowest merge index) + idx = self.merges[pair] + ids = merge(ids, pair, idx) + return ids + + def encode_ordinary(self, text): + """Encoding that ignores any special tokens.""" + # split text into chunks of text by categories defined in regex pattern + text_chunks = re.findall(self.compiled_pattern, text) + # all chunks of text are encoded separately, then results are joined + ids = [] + for chunk in text_chunks: + chunk_bytes = chunk.encode("utf-8") # raw bytes + chunk_ids = self._encode_chunk(chunk_bytes) + ids.extend(chunk_ids) + return ids + +# ----------------------------------------------------------------------------- +# Faster Python tokenizer, optimized version of the reference tokenizer + +def fast_merge_inplace(ids, pair, idx): + """ + In the list of integers (ids), replace all consecutive occurrences + of pair with the new integer token idx in place + Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] + """ + # Find all positions where the pair occurs + i = 0 + while i < len(ids) - 1: + if ids[i] == pair[0] and ids[i+1] == pair[1]: + ids[i] = idx + ids.pop(i+1) + else: + i += 1 + return ids + + +class FastRegexTokenizer: + + def __init__(self, pattern=None): + """ + - pattern: optional string to override the default (GPT-4 split pattern) + - special_tokens: str -> int dictionary of special tokens + example: {'<|endoftext|>': 100257} + """ + self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern + self.compiled_pattern = re.compile(self.pattern) + self.special_tokens = {} + self.inverse_special_tokens = {} + self.merges = {} + self.vocab = self._build_vocab() + + def _build_vocab(self): + # vocab is simply and deterministically derived from merges + vocab = {idx: bytes([idx]) for idx in range(256)} + for (p0, p1), idx in self.merges.items(): + vocab[idx] = vocab[p0] + vocab[p1] + for special, idx in self.special_tokens.items(): + vocab[idx] = special.encode("utf-8") + return vocab + + def train(self, text, vocab_size, verbose=False): + """ + A number of optimizations are introduced: + - delete function call overhead by inlining functions + - modifying list of ids in place with .pop() instead 
of creating a new list + - collapse identical chunks to just the unique ones + - update counts more cleverly - only around the affected chunks + """ + assert vocab_size >= 256 + num_merges = vocab_size - 256 + + # split the text up into text chunks + text_chunks = re.findall(self.compiled_pattern, text) + + # many, many chunks are identical, so we can "collapse" them to just the unique ones + counts = Counter(text_chunks) + unique_chunks = [ch for ch, count in counts.items()] + chunk_counts = [count for ch, count in counts.items()] + + # input text preprocessing + ids = [list(ch.encode("utf-8")) for ch in unique_chunks] + # iteratively merge the most common pairs to create new tokens + merges = {} # (int, int) -> int + vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes + + # Initial count: build stats and position tracking + stats = defaultdict(int) + positions = defaultdict(set) # pair -> set of chunk indices that contain this pair + + for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)): + for pair in zip(chunk_ids, chunk_ids[1:]): + stats[pair] += count + positions[pair].add(chunk_idx) + + for i in range(num_merges): + if not stats: + break + + # find the pair with the highest count + pair = max(stats, key=stats.get) + # mint a new token: assign it the next available id + idx = 256 + i + + # Get chunks that contain this pair + affected_chunks = positions[pair] + + # Track count changes for incremental update + count_changes = defaultdict(int) + + # Replace all occurrences of pair in affected chunks only + for chunk_idx in affected_chunks: + chunk_ids = ids[chunk_idx] + chunk_count = chunk_counts[chunk_idx] + ix = 0 + while ix < len(chunk_ids) - 1: + if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]: + # Track what pairs are being removed/added + # Remove: (prev, A), (A, B), (B, next) + if ix > 0: + old_left = (chunk_ids[ix-1], chunk_ids[ix]) + count_changes[old_left] -= chunk_count + + # The merged pair disappears + count_changes[pair] -= chunk_count + + if ix + 2 < len(chunk_ids): + old_right = (chunk_ids[ix+1], chunk_ids[ix+2]) + count_changes[old_right] -= chunk_count + + # Apply the merge + chunk_ids[ix] = idx + chunk_ids.pop(ix+1) + + # Add: (prev, C), (C, next) + if ix > 0: + new_left = (chunk_ids[ix-1], chunk_ids[ix]) + count_changes[new_left] += chunk_count + + if ix + 1 < len(chunk_ids): + new_right = (chunk_ids[ix], chunk_ids[ix+1]) + count_changes[new_right] += chunk_count + else: + ix += 1 + + # Apply incremental changes to stats and positions + for changed_pair, delta in count_changes.items(): + if changed_pair == pair: + # The merged pair should disappear completely + continue + + stats[changed_pair] += delta + + # Update positions for changed pairs - only check affected chunks + for chunk_idx in affected_chunks: + chunk_ids = ids[chunk_idx] + contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair + for j in range(len(chunk_ids) - 1)) + if contains_pair: + positions[changed_pair].add(chunk_idx) + else: + positions[changed_pair].discard(chunk_idx) + + # Remove the merged pair completely + del stats[pair] + del positions[pair] + + # save the merge + merges[pair] = idx + vocab[idx] = vocab[pair[0]] + vocab[pair[1]] + + # save class variables + self.merges = merges # used in encode() + self.vocab = vocab # used in decode() + + def register_special_tokens(self, special_tokens): + # special_tokens is a dictionary of str -> int + # example: {"<|endoftext|>": 100257} + self.special_tokens = special_tokens + 
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} + + def decode(self, ids): + # given ids (list of integers), return Python string + part_bytes = [] + for idx in ids: + if idx in self.vocab: + part_bytes.append(self.vocab[idx]) + elif idx in self.inverse_special_tokens: + part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) + else: + raise ValueError(f"invalid token id: {idx}") + text_bytes = b"".join(part_bytes) + text = text_bytes.decode("utf-8", errors="replace") + return text + + def _encode_chunk(self, text_bytes): + # return the token ids + # let's begin. first, convert all bytes to integers in range 0..255 + ids = list(text_bytes) + while len(ids) >= 2: + # find the pair with the lowest merge index + stats = get_stats(ids) + pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) + # subtle: if there are no more merges available, the key will + # result in an inf for every single pair, and the min will be + # just the first pair in the list, arbitrarily + # we can detect this terminating case by a membership check + if pair not in self.merges: + break # nothing else can be merged anymore + # otherwise let's merge the best pair (lowest merge index) + idx = self.merges[pair] + ids = fast_merge_inplace(ids, pair, idx) + return ids + + def encode_ordinary(self, text): + """Encoding that ignores any special tokens.""" + # split text into chunks of text by categories defined in regex pattern + text_chunks = re.findall(self.compiled_pattern, text) + # all chunks of text are encoded separately, then results are joined + ids = [] + for chunk in text_chunks: + chunk_bytes = chunk.encode("utf-8") # raw bytes + chunk_ids = self._encode_chunk(chunk_bytes) + ids.extend(chunk_ids) + return ids + +# ----------------------------------------------------------------------------- +# HuggingFace tokenizer +from tokenizers import Tokenizer as HFTokenizer +from tokenizers import pre_tokenizers, decoders, Regex +from tokenizers.models import BPE +from tokenizers.trainers import BpeTrainer + +class HuggingFaceTokenizer: + """Light wrapper around HuggingFace Tokenizer for some utilities""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + @classmethod + def train_from_iterator(cls, text_iterator, vocab_size): + # train from an iterator of text + # Configure the HuggingFace Tokenizer + tokenizer = HFTokenizer(BPE( + byte_fallback=True, # needed! + unk_token=None, + fuse_unk=False, + )) + # Normalizer: None + tokenizer.normalizer = None + # Pre-tokenizer: GPT-4 style + gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! 
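+         # (added note) Split treats a plain str pattern as a literal string
+         # to split on; wrapping the GPT-4 pattern in Regex() is what makes it
+         # be interpreted as a regular expression rather than matched verbatim.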
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ + pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) + ]) + # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) + tokenizer.decoder = decoders.ByteLevel() + # Post-processor: None + tokenizer.post_processor = None + # Trainer: BPE + trainer = BpeTrainer( + vocab_size=vocab_size, + show_progress=True, + min_frequency=0, # no minimum frequency + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + special_tokens=[], # no special tokens + ) + # Kick off the training + tokenizer.train_from_iterator(text_iterator, trainer) + return cls(tokenizer) + + def encode_ordinary(self, text): + ids = self.tokenizer.encode(text, add_special_tokens=False).ids + return ids + +# ----------------------------------------------------------------------------- +# Test all of the above + +@pytest.fixture(scope="module") +def enwik8_path(): + """Fixture to download and cache enwik8 dataset.""" + import os + import zipfile + from nanochat.common import get_base_dir + base_dir = get_base_dir() + # download and unzip enwik8 to .cache directory + enwik8_url = "https://mattmahoney.net/dc/enwik8.zip" + enwik8_local_path = os.path.join(base_dir, "enwik8") + enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip") + if not os.path.exists(enwik8_local_path): + print(f"Downloading enwik8 to {enwik8_local_path_zip}") + import requests + response = requests.get(enwik8_url) + with open(enwik8_local_path_zip, "wb") as f: + f.write(response.content) + with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref: + zip_ref.extractall(base_dir) + print(f"Unzipped enwik8 to {enwik8_local_path}") + os.remove(enwik8_local_path_zip) + print(f"Removed {enwik8_local_path_zip}") + else: + print(f"Using existing enwik8 at {enwik8_local_path}") + return enwik8_local_path + + +@pytest.fixture(scope="module") +def enwik8_small(enwik8_path): + """Fixture providing 100KB of enwik8 for quick tests.""" + with open(enwik8_path, "r") as f: + return f.read(100_000) + +@pytest.fixture(scope="module") +def enwik8_large(enwik8_path): + """Fixture providing 10MB of enwik8 for performance tests.""" + with open(enwik8_path, "r") as f: + return f.read(10**7) + +def time_function(func, *args, **kwargs): + """Time a function call and return the result and elapsed time""" + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + elapsed = end_time - start_time + return result, elapsed + +def test_correctness(enwik8_small): + """Test that all tokenizer implementations produce the same results.""" + text = enwik8_small + encode_text = text + vocab_size = 256 + 20 # 20 merges + + # Train slow reference + print("\nTraining slow reference...") + slow_reference_tokenizer = RegexTokenizer() + ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size) + slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text) + print(f"Slow reference train time: {slow_reference_train_time:.4f}s") + print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s") + print(slow_reference_ids[:20]) + + if ambiguous_flag: + print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size") + print("The implementation could be correct but we might see different results below") + else: + print("βœ… Merge order is NOT ambiguous") + + # Train 
fast reference + print("\nTraining fast reference...") + fast_reference_tokenizer = FastRegexTokenizer() + _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size) + fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text) + print(f"Fast reference train time: {fast_reference_train_time:.4f}s") + print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s") + print(fast_reference_ids[:20]) + + # Assert fast equals slow + assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference" + print("βœ… Fast == Slow") + + # Train HuggingFace + print("\nTraining HuggingFace...") + hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size) + hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text) + print(f"HuggingFace train time: {hf_train_time:.4f}s") + print(f"HuggingFace encode time: {hf_encode_time:.4f}s") + print(hf_ids[:20]) + + # HuggingFace has a different byte order, so we need custom matching + def custom_match(ids1, ids2): + perm = {} + for x, y in zip(ids1, ids2): + if x < 256: + if x in perm: + if perm[x] != y: + return False + perm[x] = y + if x >= 256 and x != y: + return False + return True + + assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference" + print("βœ… HuggingFace == Fast") + + # Finally use our own Rust implementation + print("\nTraining rustbpe...") + rustbpe_tokenizer = rustbpe.Tokenizer() + _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size) + rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text) + print(f"RustBPE train time: {rustbpe_train_time:.4f}s") + print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s") + print(rustbpe_ids[:20]) + + assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference" + print("βœ… RustBPE == Fast") + + # Now export rustbpe to tiktoken for more efficient inference + print("\nTesting tiktoken export...") + pattern = rustbpe_tokenizer.get_pattern() + mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks() + mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list} + enc = tiktoken.Encoding( + name="rustbpe", + pat_str=pattern, + mergeable_ranks=mergeable_ranks, + special_tokens={}, + ) + tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text) + print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s") + print(tiktoken_ids[:20]) + + assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE" + print("βœ… Tiktoken == RustBPE") + + +@pytest.mark.slow +def test_training_performance(enwik8_large): + """Use a bigger dataset and compare the training speed of the optimized tokenizers (Python, Rust, HuggingFace).""" + text = enwik8_large + vocab_size = 2048 + print(f"\nText length: {len(text)}") + + # Commenting out because it's just way too slow to matter + # Train optimized python version + # print("Training optimized python version...") + # optimized_python_tokenizer = FastRegexTokenizer() + # _, optimized_python_train_time = time_function(optimized_python_tokenizer.train, text, vocab_size) + # print(f"Optimized python train time: {optimized_python_train_time:.4f}s") + + # Train rustbpe + print("\nTraining rustbpe...") + rustbpe_tokenizer = rustbpe.Tokenizer() + _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size) + 
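+     # (added sketch, not part of the original test) under the same assumptions
+     # as test_correctness above, the tokenizer trained here could be exported
+     # to tiktoken for fast inference without re-training, e.g.:
+     #   ranks = {bytes(k): v for k, v in rustbpe_tokenizer.get_mergeable_ranks()}
+     #   enc = tiktoken.Encoding(name="rustbpe", pat_str=rustbpe_tokenizer.get_pattern(),
+     #                           mergeable_ranks=ranks, special_tokens={})
+     # the timings below therefore compare training cost only.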
print(f"RustBPE train time: {rustbpe_train_time:.4f}s") + assert rustbpe_train_time > 0, "Training should take some time" + + # Train HuggingFace + print("\nTraining HuggingFace...") + hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size) + print(f"HuggingFace train time: {hf_train_time:.4f}s") + assert hf_train_time > 0, "Training should take some time" + + # Print comparison + print(f"\nπŸ“Š Performance comparison:") + print(f" RustBPE: {rustbpe_train_time:.4f}s") + print(f" HuggingFace: {hf_train_time:.4f}s") + print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x") + +def test_interface(enwik8_small): + """Test the RustBPETokenizer interface for training, encoding, decoding, and serialization.""" + import tempfile + from nanochat.tokenizer import RustBPETokenizer + + # Simple train test + vocab_size = 300 + tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size) + assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}" + print(f"βœ… Trained tokenizer with vocab size {vocab_size}") + + # Encode/decode text + encode_text = "Hello world! How are you? πŸ™ƒ" + ids = tok.encode(encode_text) + print(f"\nInput text: {encode_text}") + print(f"IDs: {ids}") + decoded = tok.decode(ids) + print(f"Decoded: {decoded}") + assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}" + print("βœ… Encode/decode test passed") + + # Encode batch test + ids_new = tok.encode([encode_text, encode_text]) + assert all(x == ids for x in ids_new), "Batch encoding should produce identical results" + print("βœ… Encode batch OK") + + # append/prepend functionality + ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>") + bos_token_id = tok.encode_special("<|bos|>") + assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added" + print("βœ… append/prepend OK") + + # Save/load test through a temporary directory + with tempfile.TemporaryDirectory() as tmp_dir: + tok.save(tmp_dir) + tok_reloaded = RustBPETokenizer.from_directory(tmp_dir) + ids_reloaded = tok_reloaded.encode(encode_text) + assert ids_reloaded == ids, "Reloaded tokenizer should produce same results" + print("βœ… Save/load through temporary directory OK") diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..7636b81 --- /dev/null +++ b/uv.lock @@ -0,0 +1,2004 @@ +version = 1 +revision = 3 +requires-python = ">=3.10" +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", +] + +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/26/30/f84a107a9c4331c14b2b586036f40965c128aa4fee4dda5d3d51cb14ad54/aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558", size = 22760, upload-time = "2025-03-12T01:42:48.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/15/5bf3b99495fb160b63f95972b81750f18f7f4e02ad051373b669d17d44f2/aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = 
"sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8", size = 15265, upload-time = "2025-03-12T01:42:47.083Z" }, +] + +[[package]] +name = "aiohttp" +version = "3.12.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohappyeyeballs" }, + { name = "aiosignal" }, + { name = "async-timeout", marker = "python_full_version < '3.11'" }, + { name = "attrs" }, + { name = "frozenlist" }, + { name = "multidict" }, + { name = "propcache" }, + { name = "yarl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/e7/d92a237d8802ca88483906c388f7c201bbe96cd80a165ffd0ac2f6a8d59f/aiohttp-3.12.15.tar.gz", hash = "sha256:4fc61385e9c98d72fcdf47e6dd81833f47b2f77c114c29cd64a361be57a763a2", size = 7823716, upload-time = "2025-07-29T05:52:32.215Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/dc/ef9394bde9080128ad401ac7ede185267ed637df03b51f05d14d1c99ad67/aiohttp-3.12.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b6fc902bff74d9b1879ad55f5404153e2b33a82e72a95c89cec5eb6cc9e92fbc", size = 703921, upload-time = "2025-07-29T05:49:43.584Z" }, + { url = "https://files.pythonhosted.org/packages/8f/42/63fccfc3a7ed97eb6e1a71722396f409c46b60a0552d8a56d7aad74e0df5/aiohttp-3.12.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:098e92835b8119b54c693f2f88a1dec690e20798ca5f5fe5f0520245253ee0af", size = 480288, upload-time = "2025-07-29T05:49:47.851Z" }, + { url = "https://files.pythonhosted.org/packages/9c/a2/7b8a020549f66ea2a68129db6960a762d2393248f1994499f8ba9728bbed/aiohttp-3.12.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:40b3fee496a47c3b4a39a731954c06f0bd9bd3e8258c059a4beb76ac23f8e421", size = 468063, upload-time = "2025-07-29T05:49:49.789Z" }, + { url = "https://files.pythonhosted.org/packages/8f/f5/d11e088da9176e2ad8220338ae0000ed5429a15f3c9dfd983f39105399cd/aiohttp-3.12.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ce13fcfb0bb2f259fb42106cdc63fa5515fb85b7e87177267d89a771a660b79", size = 1650122, upload-time = "2025-07-29T05:49:51.874Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6b/b60ce2757e2faed3d70ed45dafee48cee7bfb878785a9423f7e883f0639c/aiohttp-3.12.15-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3beb14f053222b391bf9cf92ae82e0171067cc9c8f52453a0f1ec7c37df12a77", size = 1624176, upload-time = "2025-07-29T05:49:53.805Z" }, + { url = "https://files.pythonhosted.org/packages/dd/de/8c9fde2072a1b72c4fadecf4f7d4be7a85b1d9a4ab333d8245694057b4c6/aiohttp-3.12.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c39e87afe48aa3e814cac5f535bc6199180a53e38d3f51c5e2530f5aa4ec58c", size = 1696583, upload-time = "2025-07-29T05:49:55.338Z" }, + { url = "https://files.pythonhosted.org/packages/0c/ad/07f863ca3d895a1ad958a54006c6dafb4f9310f8c2fdb5f961b8529029d3/aiohttp-3.12.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5f1b4ce5bc528a6ee38dbf5f39bbf11dd127048726323b72b8e85769319ffc4", size = 1738896, upload-time = "2025-07-29T05:49:57.045Z" }, + { url = "https://files.pythonhosted.org/packages/20/43/2bd482ebe2b126533e8755a49b128ec4e58f1a3af56879a3abdb7b42c54f/aiohttp-3.12.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1004e67962efabbaf3f03b11b4c43b834081c9e3f9b32b16a7d97d4708a9abe6", size = 1643561, upload-time = "2025-07-29T05:49:58.762Z" }, + { url = 
"https://files.pythonhosted.org/packages/23/40/2fa9f514c4cf4cbae8d7911927f81a1901838baf5e09a8b2c299de1acfe5/aiohttp-3.12.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8faa08fcc2e411f7ab91d1541d9d597d3a90e9004180edb2072238c085eac8c2", size = 1583685, upload-time = "2025-07-29T05:50:00.375Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c3/94dc7357bc421f4fb978ca72a201a6c604ee90148f1181790c129396ceeb/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fe086edf38b2222328cdf89af0dde2439ee173b8ad7cb659b4e4c6f385b2be3d", size = 1627533, upload-time = "2025-07-29T05:50:02.306Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3f/1f8911fe1844a07001e26593b5c255a685318943864b27b4e0267e840f95/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:79b26fe467219add81d5e47b4a4ba0f2394e8b7c7c3198ed36609f9ba161aecb", size = 1638319, upload-time = "2025-07-29T05:50:04.282Z" }, + { url = "https://files.pythonhosted.org/packages/4e/46/27bf57a99168c4e145ffee6b63d0458b9c66e58bb70687c23ad3d2f0bd17/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b761bac1192ef24e16706d761aefcb581438b34b13a2f069a6d343ec8fb693a5", size = 1613776, upload-time = "2025-07-29T05:50:05.863Z" }, + { url = "https://files.pythonhosted.org/packages/0f/7e/1d2d9061a574584bb4ad3dbdba0da90a27fdc795bc227def3a46186a8bc1/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e153e8adacfe2af562861b72f8bc47f8a5c08e010ac94eebbe33dc21d677cd5b", size = 1693359, upload-time = "2025-07-29T05:50:07.563Z" }, + { url = "https://files.pythonhosted.org/packages/08/98/bee429b52233c4a391980a5b3b196b060872a13eadd41c3a34be9b1469ed/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:fc49c4de44977aa8601a00edbf157e9a421f227aa7eb477d9e3df48343311065", size = 1716598, upload-time = "2025-07-29T05:50:09.33Z" }, + { url = "https://files.pythonhosted.org/packages/57/39/b0314c1ea774df3392751b686104a3938c63ece2b7ce0ba1ed7c0b4a934f/aiohttp-3.12.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2776c7ec89c54a47029940177e75c8c07c29c66f73464784971d6a81904ce9d1", size = 1644940, upload-time = "2025-07-29T05:50:11.334Z" }, + { url = "https://files.pythonhosted.org/packages/1b/83/3dacb8d3f8f512c8ca43e3fa8a68b20583bd25636ffa4e56ee841ffd79ae/aiohttp-3.12.15-cp310-cp310-win32.whl", hash = "sha256:2c7d81a277fa78b2203ab626ced1487420e8c11a8e373707ab72d189fcdad20a", size = 429239, upload-time = "2025-07-29T05:50:12.803Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f9/470b5daba04d558c9673ca2034f28d067f3202a40e17804425f0c331c89f/aiohttp-3.12.15-cp310-cp310-win_amd64.whl", hash = "sha256:83603f881e11f0f710f8e2327817c82e79431ec976448839f3cd05d7afe8f830", size = 452297, upload-time = "2025-07-29T05:50:14.266Z" }, + { url = "https://files.pythonhosted.org/packages/20/19/9e86722ec8e835959bd97ce8c1efa78cf361fa4531fca372551abcc9cdd6/aiohttp-3.12.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d3ce17ce0220383a0f9ea07175eeaa6aa13ae5a41f30bc61d84df17f0e9b1117", size = 711246, upload-time = "2025-07-29T05:50:15.937Z" }, + { url = "https://files.pythonhosted.org/packages/71/f9/0a31fcb1a7d4629ac9d8f01f1cb9242e2f9943f47f5d03215af91c3c1a26/aiohttp-3.12.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:010cc9bbd06db80fe234d9003f67e97a10fe003bfbedb40da7d71c1008eda0fe", size = 483515, upload-time = "2025-07-29T05:50:17.442Z" }, + { url = 
"https://files.pythonhosted.org/packages/62/6c/94846f576f1d11df0c2e41d3001000527c0fdf63fce7e69b3927a731325d/aiohttp-3.12.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f9d7c55b41ed687b9d7165b17672340187f87a773c98236c987f08c858145a9", size = 471776, upload-time = "2025-07-29T05:50:19.568Z" }, + { url = "https://files.pythonhosted.org/packages/f8/6c/f766d0aaafcee0447fad0328da780d344489c042e25cd58fde566bf40aed/aiohttp-3.12.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc4fbc61bb3548d3b482f9ac7ddd0f18c67e4225aaa4e8552b9f1ac7e6bda9e5", size = 1741977, upload-time = "2025-07-29T05:50:21.665Z" }, + { url = "https://files.pythonhosted.org/packages/17/e5/fb779a05ba6ff44d7bc1e9d24c644e876bfff5abe5454f7b854cace1b9cc/aiohttp-3.12.15-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fbc8a7c410bb3ad5d595bb7118147dfbb6449d862cc1125cf8867cb337e8728", size = 1690645, upload-time = "2025-07-29T05:50:23.333Z" }, + { url = "https://files.pythonhosted.org/packages/37/4e/a22e799c2035f5d6a4ad2cf8e7c1d1bd0923192871dd6e367dafb158b14c/aiohttp-3.12.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74dad41b3458dbb0511e760fb355bb0b6689e0630de8a22b1b62a98777136e16", size = 1789437, upload-time = "2025-07-29T05:50:25.007Z" }, + { url = "https://files.pythonhosted.org/packages/28/e5/55a33b991f6433569babb56018b2fb8fb9146424f8b3a0c8ecca80556762/aiohttp-3.12.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b6f0af863cf17e6222b1735a756d664159e58855da99cfe965134a3ff63b0b0", size = 1828482, upload-time = "2025-07-29T05:50:26.693Z" }, + { url = "https://files.pythonhosted.org/packages/c6/82/1ddf0ea4f2f3afe79dffed5e8a246737cff6cbe781887a6a170299e33204/aiohttp-3.12.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5b7fe4972d48a4da367043b8e023fb70a04d1490aa7d68800e465d1b97e493b", size = 1730944, upload-time = "2025-07-29T05:50:28.382Z" }, + { url = "https://files.pythonhosted.org/packages/1b/96/784c785674117b4cb3877522a177ba1b5e4db9ce0fd519430b5de76eec90/aiohttp-3.12.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6443cca89553b7a5485331bc9bedb2342b08d073fa10b8c7d1c60579c4a7b9bd", size = 1668020, upload-time = "2025-07-29T05:50:30.032Z" }, + { url = "https://files.pythonhosted.org/packages/12/8a/8b75f203ea7e5c21c0920d84dd24a5c0e971fe1e9b9ebbf29ae7e8e39790/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c5f40ec615e5264f44b4282ee27628cea221fcad52f27405b80abb346d9f3f8", size = 1716292, upload-time = "2025-07-29T05:50:31.983Z" }, + { url = "https://files.pythonhosted.org/packages/47/0b/a1451543475bb6b86a5cfc27861e52b14085ae232896a2654ff1231c0992/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:2abbb216a1d3a2fe86dbd2edce20cdc5e9ad0be6378455b05ec7f77361b3ab50", size = 1711451, upload-time = "2025-07-29T05:50:33.989Z" }, + { url = "https://files.pythonhosted.org/packages/55/fd/793a23a197cc2f0d29188805cfc93aa613407f07e5f9da5cd1366afd9d7c/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:db71ce547012a5420a39c1b744d485cfb823564d01d5d20805977f5ea1345676", size = 1691634, upload-time = "2025-07-29T05:50:35.846Z" }, + { url = "https://files.pythonhosted.org/packages/ca/bf/23a335a6670b5f5dfc6d268328e55a22651b440fca341a64fccf1eada0c6/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = 
"sha256:ced339d7c9b5030abad5854aa5413a77565e5b6e6248ff927d3e174baf3badf7", size = 1785238, upload-time = "2025-07-29T05:50:37.597Z" }, + { url = "https://files.pythonhosted.org/packages/57/4f/ed60a591839a9d85d40694aba5cef86dde9ee51ce6cca0bb30d6eb1581e7/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:7c7dd29c7b5bda137464dc9bfc738d7ceea46ff70309859ffde8c022e9b08ba7", size = 1805701, upload-time = "2025-07-29T05:50:39.591Z" }, + { url = "https://files.pythonhosted.org/packages/85/e0/444747a9455c5de188c0f4a0173ee701e2e325d4b2550e9af84abb20cdba/aiohttp-3.12.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:421da6fd326460517873274875c6c5a18ff225b40da2616083c5a34a7570b685", size = 1718758, upload-time = "2025-07-29T05:50:41.292Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/1006278d1ffd13a698e5dd4bfa01e5878f6bddefc296c8b62649753ff249/aiohttp-3.12.15-cp311-cp311-win32.whl", hash = "sha256:4420cf9d179ec8dfe4be10e7d0fe47d6d606485512ea2265b0d8c5113372771b", size = 428868, upload-time = "2025-07-29T05:50:43.063Z" }, + { url = "https://files.pythonhosted.org/packages/10/97/ad2b18700708452400278039272032170246a1bf8ec5d832772372c71f1a/aiohttp-3.12.15-cp311-cp311-win_amd64.whl", hash = "sha256:edd533a07da85baa4b423ee8839e3e91681c7bfa19b04260a469ee94b778bf6d", size = 453273, upload-time = "2025-07-29T05:50:44.613Z" }, + { url = "https://files.pythonhosted.org/packages/63/97/77cb2450d9b35f517d6cf506256bf4f5bda3f93a66b4ad64ba7fc917899c/aiohttp-3.12.15-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:802d3868f5776e28f7bf69d349c26fc0efadb81676d0afa88ed00d98a26340b7", size = 702333, upload-time = "2025-07-29T05:50:46.507Z" }, + { url = "https://files.pythonhosted.org/packages/83/6d/0544e6b08b748682c30b9f65640d006e51f90763b41d7c546693bc22900d/aiohttp-3.12.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f2800614cd560287be05e33a679638e586a2d7401f4ddf99e304d98878c29444", size = 476948, upload-time = "2025-07-29T05:50:48.067Z" }, + { url = "https://files.pythonhosted.org/packages/3a/1d/c8c40e611e5094330284b1aea8a4b02ca0858f8458614fa35754cab42b9c/aiohttp-3.12.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8466151554b593909d30a0a125d638b4e5f3836e5aecde85b66b80ded1cb5b0d", size = 469787, upload-time = "2025-07-29T05:50:49.669Z" }, + { url = "https://files.pythonhosted.org/packages/38/7d/b76438e70319796bfff717f325d97ce2e9310f752a267bfdf5192ac6082b/aiohttp-3.12.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e5a495cb1be69dae4b08f35a6c4579c539e9b5706f606632102c0f855bcba7c", size = 1716590, upload-time = "2025-07-29T05:50:51.368Z" }, + { url = "https://files.pythonhosted.org/packages/79/b1/60370d70cdf8b269ee1444b390cbd72ce514f0d1cd1a715821c784d272c9/aiohttp-3.12.15-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6404dfc8cdde35c69aaa489bb3542fb86ef215fc70277c892be8af540e5e21c0", size = 1699241, upload-time = "2025-07-29T05:50:53.628Z" }, + { url = "https://files.pythonhosted.org/packages/a3/2b/4968a7b8792437ebc12186db31523f541943e99bda8f30335c482bea6879/aiohttp-3.12.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ead1c00f8521a5c9070fcb88f02967b1d8a0544e6d85c253f6968b785e1a2ab", size = 1754335, upload-time = "2025-07-29T05:50:55.394Z" }, + { url = "https://files.pythonhosted.org/packages/fb/c1/49524ed553f9a0bec1a11fac09e790f49ff669bcd14164f9fab608831c4d/aiohttp-3.12.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:6990ef617f14450bc6b34941dba4f12d5613cbf4e33805932f853fbd1cf18bfb", size = 1800491, upload-time = "2025-07-29T05:50:57.202Z" }, + { url = "https://files.pythonhosted.org/packages/de/5e/3bf5acea47a96a28c121b167f5ef659cf71208b19e52a88cdfa5c37f1fcc/aiohttp-3.12.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd736ed420f4db2b8148b52b46b88ed038d0354255f9a73196b7bbce3ea97545", size = 1719929, upload-time = "2025-07-29T05:50:59.192Z" }, + { url = "https://files.pythonhosted.org/packages/39/94/8ae30b806835bcd1cba799ba35347dee6961a11bd507db634516210e91d8/aiohttp-3.12.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c5092ce14361a73086b90c6efb3948ffa5be2f5b6fbcf52e8d8c8b8848bb97c", size = 1635733, upload-time = "2025-07-29T05:51:01.394Z" }, + { url = "https://files.pythonhosted.org/packages/7a/46/06cdef71dd03acd9da7f51ab3a9107318aee12ad38d273f654e4f981583a/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aaa2234bb60c4dbf82893e934d8ee8dea30446f0647e024074237a56a08c01bd", size = 1696790, upload-time = "2025-07-29T05:51:03.657Z" }, + { url = "https://files.pythonhosted.org/packages/02/90/6b4cfaaf92ed98d0ec4d173e78b99b4b1a7551250be8937d9d67ecb356b4/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:6d86a2fbdd14192e2f234a92d3b494dd4457e683ba07e5905a0b3ee25389ac9f", size = 1718245, upload-time = "2025-07-29T05:51:05.911Z" }, + { url = "https://files.pythonhosted.org/packages/2e/e6/2593751670fa06f080a846f37f112cbe6f873ba510d070136a6ed46117c6/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a041e7e2612041a6ddf1c6a33b883be6a421247c7afd47e885969ee4cc58bd8d", size = 1658899, upload-time = "2025-07-29T05:51:07.753Z" }, + { url = "https://files.pythonhosted.org/packages/8f/28/c15bacbdb8b8eb5bf39b10680d129ea7410b859e379b03190f02fa104ffd/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5015082477abeafad7203757ae44299a610e89ee82a1503e3d4184e6bafdd519", size = 1738459, upload-time = "2025-07-29T05:51:09.56Z" }, + { url = "https://files.pythonhosted.org/packages/00/de/c269cbc4faa01fb10f143b1670633a8ddd5b2e1ffd0548f7aa49cb5c70e2/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:56822ff5ddfd1b745534e658faba944012346184fbfe732e0d6134b744516eea", size = 1766434, upload-time = "2025-07-29T05:51:11.423Z" }, + { url = "https://files.pythonhosted.org/packages/52/b0/4ff3abd81aa7d929b27d2e1403722a65fc87b763e3a97b3a2a494bfc63bc/aiohttp-3.12.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b2acbbfff69019d9014508c4ba0401822e8bae5a5fdc3b6814285b71231b60f3", size = 1726045, upload-time = "2025-07-29T05:51:13.689Z" }, + { url = "https://files.pythonhosted.org/packages/71/16/949225a6a2dd6efcbd855fbd90cf476052e648fb011aa538e3b15b89a57a/aiohttp-3.12.15-cp312-cp312-win32.whl", hash = "sha256:d849b0901b50f2185874b9a232f38e26b9b3d4810095a7572eacea939132d4e1", size = 423591, upload-time = "2025-07-29T05:51:15.452Z" }, + { url = "https://files.pythonhosted.org/packages/2b/d8/fa65d2a349fe938b76d309db1a56a75c4fb8cc7b17a398b698488a939903/aiohttp-3.12.15-cp312-cp312-win_amd64.whl", hash = "sha256:b390ef5f62bb508a9d67cb3bba9b8356e23b3996da7062f1a57ce1a79d2b3d34", size = 450266, upload-time = "2025-07-29T05:51:17.239Z" }, + { url = "https://files.pythonhosted.org/packages/f2/33/918091abcf102e39d15aba2476ad9e7bd35ddb190dcdd43a854000d3da0d/aiohttp-3.12.15-cp313-cp313-macosx_10_13_universal2.whl", hash = 
"sha256:9f922ffd05034d439dde1c77a20461cf4a1b0831e6caa26151fe7aa8aaebc315", size = 696741, upload-time = "2025-07-29T05:51:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/7495a81e39a998e400f3ecdd44a62107254803d1681d9189be5c2e4530cd/aiohttp-3.12.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ee8a8ac39ce45f3e55663891d4b1d15598c157b4d494a4613e704c8b43112cd", size = 474407, upload-time = "2025-07-29T05:51:21.165Z" }, + { url = "https://files.pythonhosted.org/packages/49/fc/a9576ab4be2dcbd0f73ee8675d16c707cfc12d5ee80ccf4015ba543480c9/aiohttp-3.12.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3eae49032c29d356b94eee45a3f39fdf4b0814b397638c2f718e96cfadf4c4e4", size = 466703, upload-time = "2025-07-29T05:51:22.948Z" }, + { url = "https://files.pythonhosted.org/packages/09/2f/d4bcc8448cf536b2b54eed48f19682031ad182faa3a3fee54ebe5b156387/aiohttp-3.12.15-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b97752ff12cc12f46a9b20327104448042fce5c33a624f88c18f66f9368091c7", size = 1705532, upload-time = "2025-07-29T05:51:25.211Z" }, + { url = "https://files.pythonhosted.org/packages/f1/f3/59406396083f8b489261e3c011aa8aee9df360a96ac8fa5c2e7e1b8f0466/aiohttp-3.12.15-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:894261472691d6fe76ebb7fcf2e5870a2ac284c7406ddc95823c8598a1390f0d", size = 1686794, upload-time = "2025-07-29T05:51:27.145Z" }, + { url = "https://files.pythonhosted.org/packages/dc/71/164d194993a8d114ee5656c3b7ae9c12ceee7040d076bf7b32fb98a8c5c6/aiohttp-3.12.15-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5fa5d9eb82ce98959fc1031c28198b431b4d9396894f385cb63f1e2f3f20ca6b", size = 1738865, upload-time = "2025-07-29T05:51:29.366Z" }, + { url = "https://files.pythonhosted.org/packages/1c/00/d198461b699188a93ead39cb458554d9f0f69879b95078dce416d3209b54/aiohttp-3.12.15-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0fa751efb11a541f57db59c1dd821bec09031e01452b2b6217319b3a1f34f3d", size = 1788238, upload-time = "2025-07-29T05:51:31.285Z" }, + { url = "https://files.pythonhosted.org/packages/85/b8/9e7175e1fa0ac8e56baa83bf3c214823ce250d0028955dfb23f43d5e61fd/aiohttp-3.12.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5346b93e62ab51ee2a9d68e8f73c7cf96ffb73568a23e683f931e52450e4148d", size = 1710566, upload-time = "2025-07-29T05:51:33.219Z" }, + { url = "https://files.pythonhosted.org/packages/59/e4/16a8eac9df39b48ae102ec030fa9f726d3570732e46ba0c592aeeb507b93/aiohttp-3.12.15-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:049ec0360f939cd164ecbfd2873eaa432613d5e77d6b04535e3d1fbae5a9e645", size = 1624270, upload-time = "2025-07-29T05:51:35.195Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f8/cd84dee7b6ace0740908fd0af170f9fab50c2a41ccbc3806aabcb1050141/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b52dcf013b57464b6d1e51b627adfd69a8053e84b7103a7cd49c030f9ca44461", size = 1677294, upload-time = "2025-07-29T05:51:37.215Z" }, + { url = "https://files.pythonhosted.org/packages/ce/42/d0f1f85e50d401eccd12bf85c46ba84f947a84839c8a1c2c5f6e8ab1eb50/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9b2af240143dd2765e0fb661fd0361a1b469cab235039ea57663cda087250ea9", size = 1708958, upload-time = "2025-07-29T05:51:39.328Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/6b/f6fa6c5790fb602538483aa5a1b86fcbad66244997e5230d88f9412ef24c/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ac77f709a2cde2cc71257ab2d8c74dd157c67a0558a0d2799d5d571b4c63d44d", size = 1651553, upload-time = "2025-07-29T05:51:41.356Z" }, + { url = "https://files.pythonhosted.org/packages/04/36/a6d36ad545fa12e61d11d1932eef273928b0495e6a576eb2af04297fdd3c/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:47f6b962246f0a774fbd3b6b7be25d59b06fdb2f164cf2513097998fc6a29693", size = 1727688, upload-time = "2025-07-29T05:51:43.452Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c8/f195e5e06608a97a4e52c5d41c7927301bf757a8e8bb5bbf8cef6c314961/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:760fb7db442f284996e39cf9915a94492e1896baac44f06ae551974907922b64", size = 1761157, upload-time = "2025-07-29T05:51:45.643Z" }, + { url = "https://files.pythonhosted.org/packages/05/6a/ea199e61b67f25ba688d3ce93f63b49b0a4e3b3d380f03971b4646412fc6/aiohttp-3.12.15-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad702e57dc385cae679c39d318def49aef754455f237499d5b99bea4ef582e51", size = 1710050, upload-time = "2025-07-29T05:51:48.203Z" }, + { url = "https://files.pythonhosted.org/packages/b4/2e/ffeb7f6256b33635c29dbed29a22a723ff2dd7401fff42ea60cf2060abfb/aiohttp-3.12.15-cp313-cp313-win32.whl", hash = "sha256:f813c3e9032331024de2eb2e32a88d86afb69291fbc37a3a3ae81cc9917fb3d0", size = 422647, upload-time = "2025-07-29T05:51:50.718Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8e/78ee35774201f38d5e1ba079c9958f7629b1fd079459aea9467441dbfbf5/aiohttp-3.12.15-cp313-cp313-win_amd64.whl", hash = "sha256:1a649001580bdb37c6fdb1bebbd7e3bc688e8ec2b5c6f52edbb664662b17dc84", size = 449067, upload-time = "2025-07-29T05:51:52.549Z" }, +] + +[[package]] +name = "aiosignal" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "frozenlist" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "anyio" +version = "4.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = 
"python_full_version < '3.11'" }, + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl", hash = "sha256:60e474ac86736bbfd6f210f7a61218939c318f43f9972497381f1c5e930ed3d1", size = 107213, upload-time = "2025-08-04T08:54:24.882Z" }, +] + +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274, upload-time = "2024-11-06T16:41:39.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, +] + +[[package]] +name = "attrs" +version = "25.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b", size = 812032, upload-time = "2025-03-13T11:10:22.779Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, +] + +[[package]] +name = "certifi" +version = "2025.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386, upload-time = "2025-08-03T03:07:47.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = "sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371, upload-time = "2025-08-09T07:57:28.46Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d6/98/f3b8013223728a99b908c9344da3aa04ee6e3fa235f19409033eda92fb78/charset_normalizer-3.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb7f67a1bfa6e40b438170ebdc8158b78dc465a5a67b6dde178a46987b244a72", size = 
207695, upload-time = "2025-08-09T07:55:36.452Z" }, + { url = "https://files.pythonhosted.org/packages/21/40/5188be1e3118c82dcb7c2a5ba101b783822cfb413a0268ed3be0468532de/charset_normalizer-3.4.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc9370a2da1ac13f0153780040f465839e6cccb4a1e44810124b4e22483c93fe", size = 147153, upload-time = "2025-08-09T07:55:38.467Z" }, + { url = "https://files.pythonhosted.org/packages/37/60/5d0d74bc1e1380f0b72c327948d9c2aca14b46a9efd87604e724260f384c/charset_normalizer-3.4.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:07a0eae9e2787b586e129fdcbe1af6997f8d0e5abaa0bc98c0e20e124d67e601", size = 160428, upload-time = "2025-08-09T07:55:40.072Z" }, + { url = "https://files.pythonhosted.org/packages/85/9a/d891f63722d9158688de58d050c59dc3da560ea7f04f4c53e769de5140f5/charset_normalizer-3.4.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:74d77e25adda8581ffc1c720f1c81ca082921329452eba58b16233ab1842141c", size = 157627, upload-time = "2025-08-09T07:55:41.706Z" }, + { url = "https://files.pythonhosted.org/packages/65/1a/7425c952944a6521a9cfa7e675343f83fd82085b8af2b1373a2409c683dc/charset_normalizer-3.4.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0e909868420b7049dafd3a31d45125b31143eec59235311fc4c57ea26a4acd2", size = 152388, upload-time = "2025-08-09T07:55:43.262Z" }, + { url = "https://files.pythonhosted.org/packages/f0/c9/a2c9c2a355a8594ce2446085e2ec97fd44d323c684ff32042e2a6b718e1d/charset_normalizer-3.4.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c6f162aabe9a91a309510d74eeb6507fab5fff92337a15acbe77753d88d9dcf0", size = 150077, upload-time = "2025-08-09T07:55:44.903Z" }, + { url = "https://files.pythonhosted.org/packages/3b/38/20a1f44e4851aa1c9105d6e7110c9d020e093dfa5836d712a5f074a12bf7/charset_normalizer-3.4.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4ca4c094de7771a98d7fbd67d9e5dbf1eb73efa4f744a730437d8a3a5cf994f0", size = 161631, upload-time = "2025-08-09T07:55:46.346Z" }, + { url = "https://files.pythonhosted.org/packages/a4/fa/384d2c0f57edad03d7bec3ebefb462090d8905b4ff5a2d2525f3bb711fac/charset_normalizer-3.4.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:02425242e96bcf29a49711b0ca9f37e451da7c70562bc10e8ed992a5a7a25cc0", size = 159210, upload-time = "2025-08-09T07:55:47.539Z" }, + { url = "https://files.pythonhosted.org/packages/33/9e/eca49d35867ca2db336b6ca27617deed4653b97ebf45dfc21311ce473c37/charset_normalizer-3.4.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:78deba4d8f9590fe4dae384aeff04082510a709957e968753ff3c48399f6f92a", size = 153739, upload-time = "2025-08-09T07:55:48.744Z" }, + { url = "https://files.pythonhosted.org/packages/2a/91/26c3036e62dfe8de8061182d33be5025e2424002125c9500faff74a6735e/charset_normalizer-3.4.3-cp310-cp310-win32.whl", hash = "sha256:d79c198e27580c8e958906f803e63cddb77653731be08851c7df0b1a14a8fc0f", size = 99825, upload-time = "2025-08-09T07:55:50.305Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/f05db471f81af1fa01839d44ae2a8bfeec8d2a8b4590f16c4e7393afd323/charset_normalizer-3.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:c6e490913a46fa054e03699c70019ab869e990270597018cef1d8562132c2669", size = 107452, upload-time = "2025-08-09T07:55:51.461Z" }, + { url = 
"https://files.pythonhosted.org/packages/7f/b5/991245018615474a60965a7c9cd2b4efbaabd16d582a5547c47ee1c7730b/charset_normalizer-3.4.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b256ee2e749283ef3ddcff51a675ff43798d92d746d1a6e4631bf8c707d22d0b", size = 204483, upload-time = "2025-08-09T07:55:53.12Z" }, + { url = "https://files.pythonhosted.org/packages/c7/2a/ae245c41c06299ec18262825c1569c5d3298fc920e4ddf56ab011b417efd/charset_normalizer-3.4.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:13faeacfe61784e2559e690fc53fa4c5ae97c6fcedb8eb6fb8d0a15b475d2c64", size = 145520, upload-time = "2025-08-09T07:55:54.712Z" }, + { url = "https://files.pythonhosted.org/packages/3a/a4/b3b6c76e7a635748c4421d2b92c7b8f90a432f98bda5082049af37ffc8e3/charset_normalizer-3.4.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:00237675befef519d9af72169d8604a067d92755e84fe76492fef5441db05b91", size = 158876, upload-time = "2025-08-09T07:55:56.024Z" }, + { url = "https://files.pythonhosted.org/packages/e2/e6/63bb0e10f90a8243c5def74b5b105b3bbbfb3e7bb753915fe333fb0c11ea/charset_normalizer-3.4.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:585f3b2a80fbd26b048a0be90c5aae8f06605d3c92615911c3a2b03a8a3b796f", size = 156083, upload-time = "2025-08-09T07:55:57.582Z" }, + { url = "https://files.pythonhosted.org/packages/87/df/b7737ff046c974b183ea9aa111b74185ac8c3a326c6262d413bd5a1b8c69/charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e78314bdc32fa80696f72fa16dc61168fda4d6a0c014e0380f9d02f0e5d8a07", size = 150295, upload-time = "2025-08-09T07:55:59.147Z" }, + { url = "https://files.pythonhosted.org/packages/61/f1/190d9977e0084d3f1dc169acd060d479bbbc71b90bf3e7bf7b9927dec3eb/charset_normalizer-3.4.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:96b2b3d1a83ad55310de8c7b4a2d04d9277d5591f40761274856635acc5fcb30", size = 148379, upload-time = "2025-08-09T07:56:00.364Z" }, + { url = "https://files.pythonhosted.org/packages/4c/92/27dbe365d34c68cfe0ca76f1edd70e8705d82b378cb54ebbaeabc2e3029d/charset_normalizer-3.4.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:939578d9d8fd4299220161fdd76e86c6a251987476f5243e8864a7844476ba14", size = 160018, upload-time = "2025-08-09T07:56:01.678Z" }, + { url = "https://files.pythonhosted.org/packages/99/04/baae2a1ea1893a01635d475b9261c889a18fd48393634b6270827869fa34/charset_normalizer-3.4.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:fd10de089bcdcd1be95a2f73dbe6254798ec1bda9f450d5828c96f93e2536b9c", size = 157430, upload-time = "2025-08-09T07:56:02.87Z" }, + { url = "https://files.pythonhosted.org/packages/2f/36/77da9c6a328c54d17b960c89eccacfab8271fdaaa228305330915b88afa9/charset_normalizer-3.4.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1e8ac75d72fa3775e0b7cb7e4629cec13b7514d928d15ef8ea06bca03ef01cae", size = 151600, upload-time = "2025-08-09T07:56:04.089Z" }, + { url = "https://files.pythonhosted.org/packages/64/d4/9eb4ff2c167edbbf08cdd28e19078bf195762e9bd63371689cab5ecd3d0d/charset_normalizer-3.4.3-cp311-cp311-win32.whl", hash = "sha256:6cf8fd4c04756b6b60146d98cd8a77d0cdae0e1ca20329da2ac85eed779b6849", size = 99616, upload-time = "2025-08-09T07:56:05.658Z" }, + { url = "https://files.pythonhosted.org/packages/f4/9c/996a4a028222e7761a96634d1820de8a744ff4327a00ada9c8942033089b/charset_normalizer-3.4.3-cp311-cp311-win_amd64.whl", 
hash = "sha256:31a9a6f775f9bcd865d88ee350f0ffb0e25936a7f930ca98995c05abf1faf21c", size = 107108, upload-time = "2025-08-09T07:56:07.176Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5e/14c94999e418d9b87682734589404a25854d5f5d0408df68bc15b6ff54bb/charset_normalizer-3.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e28e334d3ff134e88989d90ba04b47d84382a828c061d0d1027b1b12a62b39b1", size = 205655, upload-time = "2025-08-09T07:56:08.475Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a8/c6ec5d389672521f644505a257f50544c074cf5fc292d5390331cd6fc9c3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cacf8f7297b0c4fcb74227692ca46b4a5852f8f4f24b3c766dd94a1075c4884", size = 146223, upload-time = "2025-08-09T07:56:09.708Z" }, + { url = "https://files.pythonhosted.org/packages/fc/eb/a2ffb08547f4e1e5415fb69eb7db25932c52a52bed371429648db4d84fb1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c6fd51128a41297f5409deab284fecbe5305ebd7e5a1f959bee1c054622b7018", size = 159366, upload-time = "2025-08-09T07:56:11.326Z" }, + { url = "https://files.pythonhosted.org/packages/82/10/0fd19f20c624b278dddaf83b8464dcddc2456cb4b02bb902a6da126b87a1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cfb2aad70f2c6debfbcb717f23b7eb55febc0bb23dcffc0f076009da10c6392", size = 157104, upload-time = "2025-08-09T07:56:13.014Z" }, + { url = "https://files.pythonhosted.org/packages/16/ab/0233c3231af734f5dfcf0844aa9582d5a1466c985bbed6cedab85af9bfe3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1606f4a55c0fd363d754049cdf400175ee96c992b1f8018b993941f221221c5f", size = 151830, upload-time = "2025-08-09T07:56:14.428Z" }, + { url = "https://files.pythonhosted.org/packages/ae/02/e29e22b4e02839a0e4a06557b1999d0a47db3567e82989b5bb21f3fbbd9f/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:027b776c26d38b7f15b26a5da1044f376455fb3766df8fc38563b4efbc515154", size = 148854, upload-time = "2025-08-09T07:56:16.051Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/e2539a0a4be302b481e8cafb5af8792da8093b486885a1ae4d15d452bcec/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:42e5088973e56e31e4fa58eb6bd709e42fc03799c11c42929592889a2e54c491", size = 160670, upload-time = "2025-08-09T07:56:17.314Z" }, + { url = "https://files.pythonhosted.org/packages/31/e7/883ee5676a2ef217a40ce0bffcc3d0dfbf9e64cbcfbdf822c52981c3304b/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc34f233c9e71701040d772aa7490318673aa7164a0efe3172b2981218c26d93", size = 158501, upload-time = "2025-08-09T07:56:18.641Z" }, + { url = "https://files.pythonhosted.org/packages/c1/35/6525b21aa0db614cf8b5792d232021dca3df7f90a1944db934efa5d20bb1/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:320e8e66157cc4e247d9ddca8e21f427efc7a04bbd0ac8a9faf56583fa543f9f", size = 153173, upload-time = "2025-08-09T07:56:20.289Z" }, + { url = "https://files.pythonhosted.org/packages/50/ee/f4704bad8201de513fdc8aac1cabc87e38c5818c93857140e06e772b5892/charset_normalizer-3.4.3-cp312-cp312-win32.whl", hash = "sha256:fb6fecfd65564f208cbf0fba07f107fb661bcd1a7c389edbced3f7a493f70e37", size = 99822, upload-time = "2025-08-09T07:56:21.551Z" }, + { url = 
"https://files.pythonhosted.org/packages/39/f5/3b3836ca6064d0992c58c7561c6b6eee1b3892e9665d650c803bd5614522/charset_normalizer-3.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:86df271bf921c2ee3818f0522e9a5b8092ca2ad8b065ece5d7d9d0e9f4849bcc", size = 107543, upload-time = "2025-08-09T07:56:23.115Z" }, + { url = "https://files.pythonhosted.org/packages/65/ca/2135ac97709b400c7654b4b764daf5c5567c2da45a30cdd20f9eefe2d658/charset_normalizer-3.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:14c2a87c65b351109f6abfc424cab3927b3bdece6f706e4d12faaf3d52ee5efe", size = 205326, upload-time = "2025-08-09T07:56:24.721Z" }, + { url = "https://files.pythonhosted.org/packages/71/11/98a04c3c97dd34e49c7d247083af03645ca3730809a5509443f3c37f7c99/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41d1fc408ff5fdfb910200ec0e74abc40387bccb3252f3f27c0676731df2b2c8", size = 146008, upload-time = "2025-08-09T07:56:26.004Z" }, + { url = "https://files.pythonhosted.org/packages/60/f5/4659a4cb3c4ec146bec80c32d8bb16033752574c20b1252ee842a95d1a1e/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1bb60174149316da1c35fa5233681f7c0f9f514509b8e399ab70fea5f17e45c9", size = 159196, upload-time = "2025-08-09T07:56:27.25Z" }, + { url = "https://files.pythonhosted.org/packages/86/9e/f552f7a00611f168b9a5865a1414179b2c6de8235a4fa40189f6f79a1753/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30d006f98569de3459c2fc1f2acde170b7b2bd265dc1943e87e1a4efe1b67c31", size = 156819, upload-time = "2025-08-09T07:56:28.515Z" }, + { url = "https://files.pythonhosted.org/packages/7e/95/42aa2156235cbc8fa61208aded06ef46111c4d3f0de233107b3f38631803/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:416175faf02e4b0810f1f38bcb54682878a4af94059a1cd63b8747244420801f", size = 151350, upload-time = "2025-08-09T07:56:29.716Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a9/3865b02c56f300a6f94fc631ef54f0a8a29da74fb45a773dfd3dcd380af7/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aab0f181c486f973bc7262a97f5aca3ee7e1437011ef0c2ec04b5a11d16c927", size = 148644, upload-time = "2025-08-09T07:56:30.984Z" }, + { url = "https://files.pythonhosted.org/packages/77/d9/cbcf1a2a5c7d7856f11e7ac2d782aec12bdfea60d104e60e0aa1c97849dc/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabf8315679312cfa71302f9bd509ded4f2f263fb5b765cf1433b39106c3cc9", size = 160468, upload-time = "2025-08-09T07:56:32.252Z" }, + { url = "https://files.pythonhosted.org/packages/f6/42/6f45efee8697b89fda4d50580f292b8f7f9306cb2971d4b53f8914e4d890/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:bd28b817ea8c70215401f657edef3a8aa83c29d447fb0b622c35403780ba11d5", size = 158187, upload-time = "2025-08-09T07:56:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/70/99/f1c3bdcfaa9c45b3ce96f70b14f070411366fa19549c1d4832c935d8e2c3/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18343b2d246dc6761a249ba1fb13f9ee9a2bcd95decc767319506056ea4ad4dc", size = 152699, upload-time = "2025-08-09T07:56:34.739Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/ad/b0081f2f99a4b194bcbb1934ef3b12aa4d9702ced80a37026b7607c72e58/charset_normalizer-3.4.3-cp313-cp313-win32.whl", hash = "sha256:6fb70de56f1859a3f71261cbe41005f56a7842cc348d3aeb26237560bfa5e0ce", size = 99580, upload-time = "2025-08-09T07:56:35.981Z" }, + { url = "https://files.pythonhosted.org/packages/9a/8f/ae790790c7b64f925e5c953b924aaa42a243fb778fed9e41f147b2a5715a/charset_normalizer-3.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:cf1ebb7d78e1ad8ec2a8c4732c7be2e736f6e5123a4146c5b89c9d1f585f8cef", size = 107366, upload-time = "2025-08-09T07:56:37.339Z" }, + { url = "https://files.pythonhosted.org/packages/8e/91/b5a06ad970ddc7a0e513112d40113e834638f4ca1120eb727a249fb2715e/charset_normalizer-3.4.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3cd35b7e8aedeb9e34c41385fda4f73ba609e561faedfae0a9e75e44ac558a15", size = 204342, upload-time = "2025-08-09T07:56:38.687Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ec/1edc30a377f0a02689342f214455c3f6c2fbedd896a1d2f856c002fc3062/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89bc04de1d83006373429975f8ef9e7932534b8cc9ca582e4db7d20d91816db", size = 145995, upload-time = "2025-08-09T07:56:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/17/e5/5e67ab85e6d22b04641acb5399c8684f4d37caf7558a53859f0283a650e9/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2001a39612b241dae17b4687898843f254f8748b796a2e16f1051a17078d991d", size = 158640, upload-time = "2025-08-09T07:56:41.311Z" }, + { url = "https://files.pythonhosted.org/packages/f1/e5/38421987f6c697ee3722981289d554957c4be652f963d71c5e46a262e135/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8dcfc373f888e4fb39a7bc57e93e3b845e7f462dacc008d9749568b1c4ece096", size = 156636, upload-time = "2025-08-09T07:56:43.195Z" }, + { url = "https://files.pythonhosted.org/packages/a0/e4/5a075de8daa3ec0745a9a3b54467e0c2967daaaf2cec04c845f73493e9a1/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b97b8404387b96cdbd30ad660f6407799126d26a39ca65729162fd810a99aa", size = 150939, upload-time = "2025-08-09T07:56:44.819Z" }, + { url = "https://files.pythonhosted.org/packages/02/f7/3611b32318b30974131db62b4043f335861d4d9b49adc6d57c1149cc49d4/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ccf600859c183d70eb47e05a44cd80a4ce77394d1ac0f79dbd2dd90a69a3a049", size = 148580, upload-time = "2025-08-09T07:56:46.684Z" }, + { url = "https://files.pythonhosted.org/packages/7e/61/19b36f4bd67f2793ab6a99b979b4e4f3d8fc754cbdffb805335df4337126/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:53cd68b185d98dde4ad8990e56a58dea83a4162161b1ea9272e5c9182ce415e0", size = 159870, upload-time = "2025-08-09T07:56:47.941Z" }, + { url = "https://files.pythonhosted.org/packages/06/57/84722eefdd338c04cf3030ada66889298eaedf3e7a30a624201e0cbe424a/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:30a96e1e1f865f78b030d65241c1ee850cdf422d869e9028e2fc1d5e4db73b92", size = 157797, upload-time = "2025-08-09T07:56:49.756Z" }, + { url = 
"https://files.pythonhosted.org/packages/72/2a/aff5dd112b2f14bcc3462c312dce5445806bfc8ab3a7328555da95330e4b/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d716a916938e03231e86e43782ca7878fb602a125a91e7acb8b5112e2e96ac16", size = 152224, upload-time = "2025-08-09T07:56:51.369Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8c/9839225320046ed279c6e839d51f028342eb77c91c89b8ef2549f951f3ec/charset_normalizer-3.4.3-cp314-cp314-win32.whl", hash = "sha256:c6dbd0ccdda3a2ba7c2ecd9d77b37f3b5831687d8dc1b6ca5f56a4880cc7b7ce", size = 100086, upload-time = "2025-08-09T07:56:52.722Z" }, + { url = "https://files.pythonhosted.org/packages/ee/7a/36fbcf646e41f710ce0a563c1c9a343c6edf9be80786edeb15b6f62e17db/charset_normalizer-3.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:73dc19b562516fc9bcf6e5d6e596df0b4eb98d87e4f79f3ae71840e6ed21361c", size = 107400, upload-time = "2025-08-09T07:56:55.172Z" }, + { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, +] + +[[package]] +name = "click" +version = "8.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "datasets" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, + { name = "filelock" }, + { name = "fsspec", extra = ["http"] }, + { name = "huggingface-hub" }, + { name = "multiprocess" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e3/9d/348ed92110ba5f9b70b51ca1078d4809767a835aa2b7ce7e74ad2b98323d/datasets-4.0.0.tar.gz", hash = "sha256:9657e7140a9050db13443ba21cb5de185af8af944479b00e7ff1e00a61c8dbf1", size = 569566, upload-time = "2025-07-09T14:35:52.431Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/eb/62/eb8157afb21bd229c864521c1ab4fa8e9b4f1b06bafdd8c4668a7a31b5dd/datasets-4.0.0-py3-none-any.whl", hash = "sha256:7ef95e62025fd122882dbce6cb904c8cd3fbc829de6669a5eb939c77d50e203d", size = 494825, upload-time = "2025-07-09T14:35:50.658Z" }, +] + +[[package]] +name = "dill" +version = "0.3.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, +] + +[[package]] +name = "exceptiongroup" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, +] + +[[package]] +name = "fastapi" +version = "0.117.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/7e/d9788300deaf416178f61fb3c2ceb16b7d0dc9f82a08fdb87a5e64ee3cc7/fastapi-0.117.1.tar.gz", hash = "sha256:fb2d42082d22b185f904ca0ecad2e195b851030bd6c5e4c032d1c981240c631a", size = 307155, upload-time = "2025-09-20T20:16:56.663Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/45/d9d3e8eeefbe93be1c50060a9d9a9f366dba66f288bb518a9566a23a8631/fastapi-0.117.1-py3-none-any.whl", hash = "sha256:33c51a0d21cab2b9722d4e56dbb9316f3687155be6b276191790d8da03507552", size = 95959, upload-time = "2025-09-20T20:16:53.661Z" }, +] + +[[package]] +name = "filelock" +version = "3.19.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/40/bb/0ab3e58d22305b6f5440629d20683af28959bf793d98d11950e305c1c326/filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58", size = 17687, upload-time = "2025-08-14T16:56:03.016Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988, upload-time = "2025-08-14T16:56:01.633Z" }, +] + +[[package]] +name = "files-to-prompt" +version = "0.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/b9/4f/81fc86a88dc9e0cf6ea1ac2c561c0ac48b46d314cbbc2db5c8844b4b448b/files_to_prompt-0.6.tar.gz", hash = "sha256:9af57eecbdb29d3cce034c186493ffc6c1205ea4f5abde6fb32ccb1d96eae40c", size = 12236, upload-time = "2025-02-19T05:58:28.2Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/99/0efff50ce810119d99eaa2fc0c7bbf66e4197e2defb89242f6e848004902/files_to_prompt-0.6-py3-none-any.whl", hash = "sha256:83d9a8b33246a10233218716a5c78034da4f5614748eda2f0ab94f1117801337", size = 10873, upload-time = "2025-02-19T05:58:26.728Z" }, +] + +[[package]] +name = "frozenlist" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/b1/b64018016eeb087db503b038296fd782586432b9c077fc5c7839e9cb6ef6/frozenlist-1.7.0.tar.gz", hash = "sha256:2e310d81923c2437ea8670467121cc3e9b0f76d3043cc1d2331d56c7fb7a3a8f", size = 45078, upload-time = "2025-06-09T23:02:35.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/36/0da0a49409f6b47cc2d060dc8c9040b897b5902a8a4e37d9bc1deb11f680/frozenlist-1.7.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc4df77d638aa2ed703b878dd093725b72a824c3c546c076e8fdf276f78ee84a", size = 81304, upload-time = "2025-06-09T22:59:46.226Z" }, + { url = "https://files.pythonhosted.org/packages/77/f0/77c11d13d39513b298e267b22eb6cb559c103d56f155aa9a49097221f0b6/frozenlist-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:716a9973a2cc963160394f701964fe25012600f3d311f60c790400b00e568b61", size = 47735, upload-time = "2025-06-09T22:59:48.133Z" }, + { url = "https://files.pythonhosted.org/packages/37/12/9d07fa18971a44150593de56b2f2947c46604819976784bcf6ea0d5db43b/frozenlist-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0fd1bad056a3600047fb9462cff4c5322cebc59ebf5d0a3725e0ee78955001d", size = 46775, upload-time = "2025-06-09T22:59:49.564Z" }, + { url = "https://files.pythonhosted.org/packages/70/34/f73539227e06288fcd1f8a76853e755b2b48bca6747e99e283111c18bcd4/frozenlist-1.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3789ebc19cb811163e70fe2bd354cea097254ce6e707ae42e56f45e31e96cb8e", size = 224644, upload-time = "2025-06-09T22:59:51.35Z" }, + { url = "https://files.pythonhosted.org/packages/fb/68/c1d9c2f4a6e438e14613bad0f2973567586610cc22dcb1e1241da71de9d3/frozenlist-1.7.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af369aa35ee34f132fcfad5be45fbfcde0e3a5f6a1ec0712857f286b7d20cca9", size = 222125, upload-time = "2025-06-09T22:59:52.884Z" }, + { url = "https://files.pythonhosted.org/packages/b9/d0/98e8f9a515228d708344d7c6986752be3e3192d1795f748c24bcf154ad99/frozenlist-1.7.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac64b6478722eeb7a3313d494f8342ef3478dff539d17002f849101b212ef97c", size = 233455, upload-time = "2025-06-09T22:59:54.74Z" }, + { url = "https://files.pythonhosted.org/packages/79/df/8a11bcec5600557f40338407d3e5bea80376ed1c01a6c0910fcfdc4b8993/frozenlist-1.7.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f89f65d85774f1797239693cef07ad4c97fdd0639544bad9ac4b869782eb1981", size = 227339, upload-time = "2025-06-09T22:59:56.187Z" }, + { url = "https://files.pythonhosted.org/packages/50/82/41cb97d9c9a5ff94438c63cc343eb7980dac4187eb625a51bdfdb7707314/frozenlist-1.7.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:1073557c941395fdfcfac13eb2456cb8aad89f9de27bae29fabca8e563b12615", size = 212969, upload-time = "2025-06-09T22:59:57.604Z" }, + { url = "https://files.pythonhosted.org/packages/13/47/f9179ee5ee4f55629e4f28c660b3fdf2775c8bfde8f9c53f2de2d93f52a9/frozenlist-1.7.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ed8d2fa095aae4bdc7fdd80351009a48d286635edffee66bf865e37a9125c50", size = 222862, upload-time = "2025-06-09T22:59:59.498Z" }, + { url = "https://files.pythonhosted.org/packages/1a/52/df81e41ec6b953902c8b7e3a83bee48b195cb0e5ec2eabae5d8330c78038/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:24c34bea555fe42d9f928ba0a740c553088500377448febecaa82cc3e88aa1fa", size = 222492, upload-time = "2025-06-09T23:00:01.026Z" }, + { url = "https://files.pythonhosted.org/packages/84/17/30d6ea87fa95a9408245a948604b82c1a4b8b3e153cea596421a2aef2754/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:69cac419ac6a6baad202c85aaf467b65ac860ac2e7f2ac1686dc40dbb52f6577", size = 238250, upload-time = "2025-06-09T23:00:03.401Z" }, + { url = "https://files.pythonhosted.org/packages/8f/00/ecbeb51669e3c3df76cf2ddd66ae3e48345ec213a55e3887d216eb4fbab3/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:960d67d0611f4c87da7e2ae2eacf7ea81a5be967861e0c63cf205215afbfac59", size = 218720, upload-time = "2025-06-09T23:00:05.282Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c0/c224ce0e0eb31cc57f67742071bb470ba8246623c1823a7530be0e76164c/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:41be2964bd4b15bf575e5daee5a5ce7ed3115320fb3c2b71fca05582ffa4dc9e", size = 232585, upload-time = "2025-06-09T23:00:07.962Z" }, + { url = "https://files.pythonhosted.org/packages/55/3c/34cb694abf532f31f365106deebdeac9e45c19304d83cf7d51ebbb4ca4d1/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:46d84d49e00c9429238a7ce02dc0be8f6d7cd0cd405abd1bebdc991bf27c15bd", size = 234248, upload-time = "2025-06-09T23:00:09.428Z" }, + { url = "https://files.pythonhosted.org/packages/98/c0/2052d8b6cecda2e70bd81299e3512fa332abb6dcd2969b9c80dfcdddbf75/frozenlist-1.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:15900082e886edb37480335d9d518cec978afc69ccbc30bd18610b7c1b22a718", size = 221621, upload-time = "2025-06-09T23:00:11.32Z" }, + { url = "https://files.pythonhosted.org/packages/c5/bf/7dcebae315436903b1d98ffb791a09d674c88480c158aa171958a3ac07f0/frozenlist-1.7.0-cp310-cp310-win32.whl", hash = "sha256:400ddd24ab4e55014bba442d917203c73b2846391dd42ca5e38ff52bb18c3c5e", size = 39578, upload-time = "2025-06-09T23:00:13.526Z" }, + { url = "https://files.pythonhosted.org/packages/8f/5f/f69818f017fa9a3d24d1ae39763e29b7f60a59e46d5f91b9c6b21622f4cd/frozenlist-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:6eb93efb8101ef39d32d50bce242c84bcbddb4f7e9febfa7b524532a239b4464", size = 43830, upload-time = "2025-06-09T23:00:14.98Z" }, + { url = "https://files.pythonhosted.org/packages/34/7e/803dde33760128acd393a27eb002f2020ddb8d99d30a44bfbaab31c5f08a/frozenlist-1.7.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:aa51e147a66b2d74de1e6e2cf5921890de6b0f4820b257465101d7f37b49fb5a", size = 82251, upload-time = "2025-06-09T23:00:16.279Z" }, + { url = "https://files.pythonhosted.org/packages/75/a9/9c2c5760b6ba45eae11334db454c189d43d34a4c0b489feb2175e5e64277/frozenlist-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:9b35db7ce1cd71d36ba24f80f0c9e7cff73a28d7a74e91fe83e23d27c7828750", size = 48183, upload-time = "2025-06-09T23:00:17.698Z" }, + { url = "https://files.pythonhosted.org/packages/47/be/4038e2d869f8a2da165f35a6befb9158c259819be22eeaf9c9a8f6a87771/frozenlist-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:34a69a85e34ff37791e94542065c8416c1afbf820b68f720452f636d5fb990cd", size = 47107, upload-time = "2025-06-09T23:00:18.952Z" }, + { url = "https://files.pythonhosted.org/packages/79/26/85314b8a83187c76a37183ceed886381a5f992975786f883472fcb6dc5f2/frozenlist-1.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a646531fa8d82c87fe4bb2e596f23173caec9185bfbca5d583b4ccfb95183e2", size = 237333, upload-time = "2025-06-09T23:00:20.275Z" }, + { url = "https://files.pythonhosted.org/packages/1f/fd/e5b64f7d2c92a41639ffb2ad44a6a82f347787abc0c7df5f49057cf11770/frozenlist-1.7.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:79b2ffbba483f4ed36a0f236ccb85fbb16e670c9238313709638167670ba235f", size = 231724, upload-time = "2025-06-09T23:00:21.705Z" }, + { url = "https://files.pythonhosted.org/packages/20/fb/03395c0a43a5976af4bf7534759d214405fbbb4c114683f434dfdd3128ef/frozenlist-1.7.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a26f205c9ca5829cbf82bb2a84b5c36f7184c4316617d7ef1b271a56720d6b30", size = 245842, upload-time = "2025-06-09T23:00:23.148Z" }, + { url = "https://files.pythonhosted.org/packages/d0/15/c01c8e1dffdac5d9803507d824f27aed2ba76b6ed0026fab4d9866e82f1f/frozenlist-1.7.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bcacfad3185a623fa11ea0e0634aac7b691aa925d50a440f39b458e41c561d98", size = 239767, upload-time = "2025-06-09T23:00:25.103Z" }, + { url = "https://files.pythonhosted.org/packages/14/99/3f4c6fe882c1f5514b6848aa0a69b20cb5e5d8e8f51a339d48c0e9305ed0/frozenlist-1.7.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72c1b0fe8fe451b34f12dce46445ddf14bd2a5bcad7e324987194dc8e3a74c86", size = 224130, upload-time = "2025-06-09T23:00:27.061Z" }, + { url = "https://files.pythonhosted.org/packages/4d/83/220a374bd7b2aeba9d0725130665afe11de347d95c3620b9b82cc2fcab97/frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61d1a5baeaac6c0798ff6edfaeaa00e0e412d49946c53fae8d4b8e8b3566c4ae", size = 235301, upload-time = "2025-06-09T23:00:29.02Z" }, + { url = "https://files.pythonhosted.org/packages/03/3c/3e3390d75334a063181625343e8daab61b77e1b8214802cc4e8a1bb678fc/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7edf5c043c062462f09b6820de9854bf28cc6cc5b6714b383149745e287181a8", size = 234606, upload-time = "2025-06-09T23:00:30.514Z" }, + { url = "https://files.pythonhosted.org/packages/23/1e/58232c19608b7a549d72d9903005e2d82488f12554a32de2d5fb59b9b1ba/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d50ac7627b3a1bd2dcef6f9da89a772694ec04d9a61b66cf87f7d9446b4a0c31", size = 248372, upload-time = "2025-06-09T23:00:31.966Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a4/e4a567e01702a88a74ce8a324691e62a629bf47d4f8607f24bf1c7216e7f/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ce48b2fece5aeb45265bb7a58259f45027db0abff478e3077e12b05b17fb9da7", size = 229860, upload-time = "2025-06-09T23:00:33.375Z" }, + { url = 
"https://files.pythonhosted.org/packages/73/a6/63b3374f7d22268b41a9db73d68a8233afa30ed164c46107b33c4d18ecdd/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:fe2365ae915a1fafd982c146754e1de6ab3478def8a59c86e1f7242d794f97d5", size = 245893, upload-time = "2025-06-09T23:00:35.002Z" }, + { url = "https://files.pythonhosted.org/packages/6d/eb/d18b3f6e64799a79673c4ba0b45e4cfbe49c240edfd03a68be20002eaeaa/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:45a6f2fdbd10e074e8814eb98b05292f27bad7d1883afbe009d96abdcf3bc898", size = 246323, upload-time = "2025-06-09T23:00:36.468Z" }, + { url = "https://files.pythonhosted.org/packages/5a/f5/720f3812e3d06cd89a1d5db9ff6450088b8f5c449dae8ffb2971a44da506/frozenlist-1.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:21884e23cffabb157a9dd7e353779077bf5b8f9a58e9b262c6caad2ef5f80a56", size = 233149, upload-time = "2025-06-09T23:00:37.963Z" }, + { url = "https://files.pythonhosted.org/packages/69/68/03efbf545e217d5db8446acfd4c447c15b7c8cf4dbd4a58403111df9322d/frozenlist-1.7.0-cp311-cp311-win32.whl", hash = "sha256:284d233a8953d7b24f9159b8a3496fc1ddc00f4db99c324bd5fb5f22d8698ea7", size = 39565, upload-time = "2025-06-09T23:00:39.753Z" }, + { url = "https://files.pythonhosted.org/packages/58/17/fe61124c5c333ae87f09bb67186d65038834a47d974fc10a5fadb4cc5ae1/frozenlist-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:387cbfdcde2f2353f19c2f66bbb52406d06ed77519ac7ee21be0232147c2592d", size = 44019, upload-time = "2025-06-09T23:00:40.988Z" }, + { url = "https://files.pythonhosted.org/packages/ef/a2/c8131383f1e66adad5f6ecfcce383d584ca94055a34d683bbb24ac5f2f1c/frozenlist-1.7.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3dbf9952c4bb0e90e98aec1bd992b3318685005702656bc6f67c1a32b76787f2", size = 81424, upload-time = "2025-06-09T23:00:42.24Z" }, + { url = "https://files.pythonhosted.org/packages/4c/9d/02754159955088cb52567337d1113f945b9e444c4960771ea90eb73de8db/frozenlist-1.7.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1f5906d3359300b8a9bb194239491122e6cf1444c2efb88865426f170c262cdb", size = 47952, upload-time = "2025-06-09T23:00:43.481Z" }, + { url = "https://files.pythonhosted.org/packages/01/7a/0046ef1bd6699b40acd2067ed6d6670b4db2f425c56980fa21c982c2a9db/frozenlist-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3dabd5a8f84573c8d10d8859a50ea2dec01eea372031929871368c09fa103478", size = 46688, upload-time = "2025-06-09T23:00:44.793Z" }, + { url = "https://files.pythonhosted.org/packages/d6/a2/a910bafe29c86997363fb4c02069df4ff0b5bc39d33c5198b4e9dd42d8f8/frozenlist-1.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa57daa5917f1738064f302bf2626281a1cb01920c32f711fbc7bc36111058a8", size = 243084, upload-time = "2025-06-09T23:00:46.125Z" }, + { url = "https://files.pythonhosted.org/packages/64/3e/5036af9d5031374c64c387469bfcc3af537fc0f5b1187d83a1cf6fab1639/frozenlist-1.7.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c193dda2b6d49f4c4398962810fa7d7c78f032bf45572b3e04dd5249dff27e08", size = 233524, upload-time = "2025-06-09T23:00:47.73Z" }, + { url = "https://files.pythonhosted.org/packages/06/39/6a17b7c107a2887e781a48ecf20ad20f1c39d94b2a548c83615b5b879f28/frozenlist-1.7.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe2b675cf0aaa6d61bf8fbffd3c274b3c9b7b1623beb3809df8a81399a4a9c4", size = 248493, upload-time = "2025-06-09T23:00:49.742Z" }, + { url = 
"https://files.pythonhosted.org/packages/be/00/711d1337c7327d88c44d91dd0f556a1c47fb99afc060ae0ef66b4d24793d/frozenlist-1.7.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8fc5d5cda37f62b262405cf9652cf0856839c4be8ee41be0afe8858f17f4c94b", size = 244116, upload-time = "2025-06-09T23:00:51.352Z" }, + { url = "https://files.pythonhosted.org/packages/24/fe/74e6ec0639c115df13d5850e75722750adabdc7de24e37e05a40527ca539/frozenlist-1.7.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0d5ce521d1dd7d620198829b87ea002956e4319002ef0bc8d3e6d045cb4646e", size = 224557, upload-time = "2025-06-09T23:00:52.855Z" }, + { url = "https://files.pythonhosted.org/packages/8d/db/48421f62a6f77c553575201e89048e97198046b793f4a089c79a6e3268bd/frozenlist-1.7.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:488d0a7d6a0008ca0db273c542098a0fa9e7dfaa7e57f70acef43f32b3f69dca", size = 241820, upload-time = "2025-06-09T23:00:54.43Z" }, + { url = "https://files.pythonhosted.org/packages/1d/fa/cb4a76bea23047c8462976ea7b7a2bf53997a0ca171302deae9d6dd12096/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:15a7eaba63983d22c54d255b854e8108e7e5f3e89f647fc854bd77a237e767df", size = 236542, upload-time = "2025-06-09T23:00:56.409Z" }, + { url = "https://files.pythonhosted.org/packages/5d/32/476a4b5cfaa0ec94d3f808f193301debff2ea42288a099afe60757ef6282/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:1eaa7e9c6d15df825bf255649e05bd8a74b04a4d2baa1ae46d9c2d00b2ca2cb5", size = 249350, upload-time = "2025-06-09T23:00:58.468Z" }, + { url = "https://files.pythonhosted.org/packages/8d/ba/9a28042f84a6bf8ea5dbc81cfff8eaef18d78b2a1ad9d51c7bc5b029ad16/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4389e06714cfa9d47ab87f784a7c5be91d3934cd6e9a7b85beef808297cc025", size = 225093, upload-time = "2025-06-09T23:01:00.015Z" }, + { url = "https://files.pythonhosted.org/packages/bc/29/3a32959e68f9cf000b04e79ba574527c17e8842e38c91d68214a37455786/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:73bd45e1488c40b63fe5a7df892baf9e2a4d4bb6409a2b3b78ac1c6236178e01", size = 245482, upload-time = "2025-06-09T23:01:01.474Z" }, + { url = "https://files.pythonhosted.org/packages/80/e8/edf2f9e00da553f07f5fa165325cfc302dead715cab6ac8336a5f3d0adc2/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99886d98e1643269760e5fe0df31e5ae7050788dd288947f7f007209b8c33f08", size = 249590, upload-time = "2025-06-09T23:01:02.961Z" }, + { url = "https://files.pythonhosted.org/packages/1c/80/9a0eb48b944050f94cc51ee1c413eb14a39543cc4f760ed12657a5a3c45a/frozenlist-1.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:290a172aae5a4c278c6da8a96222e6337744cd9c77313efe33d5670b9f65fc43", size = 237785, upload-time = "2025-06-09T23:01:05.095Z" }, + { url = "https://files.pythonhosted.org/packages/f3/74/87601e0fb0369b7a2baf404ea921769c53b7ae00dee7dcfe5162c8c6dbf0/frozenlist-1.7.0-cp312-cp312-win32.whl", hash = "sha256:426c7bc70e07cfebc178bc4c2bf2d861d720c4fff172181eeb4a4c41d4ca2ad3", size = 39487, upload-time = "2025-06-09T23:01:06.54Z" }, + { url = "https://files.pythonhosted.org/packages/0b/15/c026e9a9fc17585a9d461f65d8593d281fedf55fbf7eb53f16c6df2392f9/frozenlist-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:563b72efe5da92e02eb68c59cb37205457c977aa7a449ed1b37e6939e5c47c6a", size = 43874, upload-time = "2025-06-09T23:01:07.752Z" 
}, + { url = "https://files.pythonhosted.org/packages/24/90/6b2cebdabdbd50367273c20ff6b57a3dfa89bd0762de02c3a1eb42cb6462/frozenlist-1.7.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee80eeda5e2a4e660651370ebffd1286542b67e268aa1ac8d6dbe973120ef7ee", size = 79791, upload-time = "2025-06-09T23:01:09.368Z" }, + { url = "https://files.pythonhosted.org/packages/83/2e/5b70b6a3325363293fe5fc3ae74cdcbc3e996c2a11dde2fd9f1fb0776d19/frozenlist-1.7.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d1a81c85417b914139e3a9b995d4a1c84559afc839a93cf2cb7f15e6e5f6ed2d", size = 47165, upload-time = "2025-06-09T23:01:10.653Z" }, + { url = "https://files.pythonhosted.org/packages/f4/25/a0895c99270ca6966110f4ad98e87e5662eab416a17e7fd53c364bf8b954/frozenlist-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cbb65198a9132ebc334f237d7b0df163e4de83fb4f2bdfe46c1e654bdb0c5d43", size = 45881, upload-time = "2025-06-09T23:01:12.296Z" }, + { url = "https://files.pythonhosted.org/packages/19/7c/71bb0bbe0832793c601fff68cd0cf6143753d0c667f9aec93d3c323f4b55/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dab46c723eeb2c255a64f9dc05b8dd601fde66d6b19cdb82b2e09cc6ff8d8b5d", size = 232409, upload-time = "2025-06-09T23:01:13.641Z" }, + { url = "https://files.pythonhosted.org/packages/c0/45/ed2798718910fe6eb3ba574082aaceff4528e6323f9a8570be0f7028d8e9/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6aeac207a759d0dedd2e40745575ae32ab30926ff4fa49b1635def65806fddee", size = 225132, upload-time = "2025-06-09T23:01:15.264Z" }, + { url = "https://files.pythonhosted.org/packages/ba/e2/8417ae0f8eacb1d071d4950f32f229aa6bf68ab69aab797b72a07ea68d4f/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bd8c4e58ad14b4fa7802b8be49d47993182fdd4023393899632c88fd8cd994eb", size = 237638, upload-time = "2025-06-09T23:01:16.752Z" }, + { url = "https://files.pythonhosted.org/packages/f8/b7/2ace5450ce85f2af05a871b8c8719b341294775a0a6c5585d5e6170f2ce7/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04fb24d104f425da3540ed83cbfc31388a586a7696142004c577fa61c6298c3f", size = 233539, upload-time = "2025-06-09T23:01:18.202Z" }, + { url = "https://files.pythonhosted.org/packages/46/b9/6989292c5539553dba63f3c83dc4598186ab2888f67c0dc1d917e6887db6/frozenlist-1.7.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a5c505156368e4ea6b53b5ac23c92d7edc864537ff911d2fb24c140bb175e60", size = 215646, upload-time = "2025-06-09T23:01:19.649Z" }, + { url = "https://files.pythonhosted.org/packages/72/31/bc8c5c99c7818293458fe745dab4fd5730ff49697ccc82b554eb69f16a24/frozenlist-1.7.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bd7eb96a675f18aa5c553eb7ddc24a43c8c18f22e1f9925528128c052cdbe00", size = 232233, upload-time = "2025-06-09T23:01:21.175Z" }, + { url = "https://files.pythonhosted.org/packages/59/52/460db4d7ba0811b9ccb85af996019f5d70831f2f5f255f7cc61f86199795/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:05579bf020096fe05a764f1f84cd104a12f78eaab68842d036772dc6d4870b4b", size = 227996, upload-time = "2025-06-09T23:01:23.098Z" }, + { url = "https://files.pythonhosted.org/packages/ba/c9/f4b39e904c03927b7ecf891804fd3b4df3db29b9e487c6418e37988d6e9d/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash 
= "sha256:376b6222d114e97eeec13d46c486facd41d4f43bab626b7c3f6a8b4e81a5192c", size = 242280, upload-time = "2025-06-09T23:01:24.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/33/3f8d6ced42f162d743e3517781566b8481322be321b486d9d262adf70bfb/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0aa7e176ebe115379b5b1c95b4096fb1c17cce0847402e227e712c27bdb5a949", size = 217717, upload-time = "2025-06-09T23:01:26.28Z" }, + { url = "https://files.pythonhosted.org/packages/3e/e8/ad683e75da6ccef50d0ab0c2b2324b32f84fc88ceee778ed79b8e2d2fe2e/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3fbba20e662b9c2130dc771e332a99eff5da078b2b2648153a40669a6d0e36ca", size = 236644, upload-time = "2025-06-09T23:01:27.887Z" }, + { url = "https://files.pythonhosted.org/packages/b2/14/8d19ccdd3799310722195a72ac94ddc677541fb4bef4091d8e7775752360/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:f3f4410a0a601d349dd406b5713fec59b4cee7e71678d5b17edda7f4655a940b", size = 238879, upload-time = "2025-06-09T23:01:29.524Z" }, + { url = "https://files.pythonhosted.org/packages/ce/13/c12bf657494c2fd1079a48b2db49fa4196325909249a52d8f09bc9123fd7/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e2cdfaaec6a2f9327bf43c933c0319a7c429058e8537c508964a133dffee412e", size = 232502, upload-time = "2025-06-09T23:01:31.287Z" }, + { url = "https://files.pythonhosted.org/packages/d7/8b/e7f9dfde869825489382bc0d512c15e96d3964180c9499efcec72e85db7e/frozenlist-1.7.0-cp313-cp313-win32.whl", hash = "sha256:5fc4df05a6591c7768459caba1b342d9ec23fa16195e744939ba5914596ae3e1", size = 39169, upload-time = "2025-06-09T23:01:35.503Z" }, + { url = "https://files.pythonhosted.org/packages/35/89/a487a98d94205d85745080a37860ff5744b9820a2c9acbcdd9440bfddf98/frozenlist-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:52109052b9791a3e6b5d1b65f4b909703984b770694d3eb64fad124c835d7cba", size = 43219, upload-time = "2025-06-09T23:01:36.784Z" }, + { url = "https://files.pythonhosted.org/packages/56/d5/5c4cf2319a49eddd9dd7145e66c4866bdc6f3dbc67ca3d59685149c11e0d/frozenlist-1.7.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a6f86e4193bb0e235ef6ce3dde5cbabed887e0b11f516ce8a0f4d3b33078ec2d", size = 84345, upload-time = "2025-06-09T23:01:38.295Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/ec2c1e1dc16b85bc9d526009961953df9cec8481b6886debb36ec9107799/frozenlist-1.7.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:82d664628865abeb32d90ae497fb93df398a69bb3434463d172b80fc25b0dd7d", size = 48880, upload-time = "2025-06-09T23:01:39.887Z" }, + { url = "https://files.pythonhosted.org/packages/69/86/f9596807b03de126e11e7d42ac91e3d0b19a6599c714a1989a4e85eeefc4/frozenlist-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:912a7e8375a1c9a68325a902f3953191b7b292aa3c3fb0d71a216221deca460b", size = 48498, upload-time = "2025-06-09T23:01:41.318Z" }, + { url = "https://files.pythonhosted.org/packages/5e/cb/df6de220f5036001005f2d726b789b2c0b65f2363b104bbc16f5be8084f8/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9537c2777167488d539bc5de2ad262efc44388230e5118868e172dd4a552b146", size = 292296, upload-time = "2025-06-09T23:01:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/83/1f/de84c642f17c8f851a2905cee2dae401e5e0daca9b5ef121e120e19aa825/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = 
"sha256:f34560fb1b4c3e30ba35fa9a13894ba39e5acfc5f60f57d8accde65f46cc5e74", size = 273103, upload-time = "2025-06-09T23:01:44.166Z" }, + { url = "https://files.pythonhosted.org/packages/88/3c/c840bfa474ba3fa13c772b93070893c6e9d5c0350885760376cbe3b6c1b3/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:acd03d224b0175f5a850edc104ac19040d35419eddad04e7cf2d5986d98427f1", size = 292869, upload-time = "2025-06-09T23:01:45.681Z" }, + { url = "https://files.pythonhosted.org/packages/a6/1c/3efa6e7d5a39a1d5ef0abeb51c48fb657765794a46cf124e5aca2c7a592c/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2038310bc582f3d6a09b3816ab01737d60bf7b1ec70f5356b09e84fb7408ab1", size = 291467, upload-time = "2025-06-09T23:01:47.234Z" }, + { url = "https://files.pythonhosted.org/packages/4f/00/d5c5e09d4922c395e2f2f6b79b9a20dab4b67daaf78ab92e7729341f61f6/frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8c05e4c8e5f36e5e088caa1bf78a687528f83c043706640a92cb76cd6999384", size = 266028, upload-time = "2025-06-09T23:01:48.819Z" }, + { url = "https://files.pythonhosted.org/packages/4e/27/72765be905619dfde25a7f33813ac0341eb6b076abede17a2e3fbfade0cb/frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:765bb588c86e47d0b68f23c1bee323d4b703218037765dcf3f25c838c6fecceb", size = 284294, upload-time = "2025-06-09T23:01:50.394Z" }, + { url = "https://files.pythonhosted.org/packages/88/67/c94103a23001b17808eb7dd1200c156bb69fb68e63fcf0693dde4cd6228c/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:32dc2e08c67d86d0969714dd484fd60ff08ff81d1a1e40a77dd34a387e6ebc0c", size = 281898, upload-time = "2025-06-09T23:01:52.234Z" }, + { url = "https://files.pythonhosted.org/packages/42/34/a3e2c00c00f9e2a9db5653bca3fec306349e71aff14ae45ecc6d0951dd24/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:c0303e597eb5a5321b4de9c68e9845ac8f290d2ab3f3e2c864437d3c5a30cd65", size = 290465, upload-time = "2025-06-09T23:01:53.788Z" }, + { url = "https://files.pythonhosted.org/packages/bb/73/f89b7fbce8b0b0c095d82b008afd0590f71ccb3dee6eee41791cf8cd25fd/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:a47f2abb4e29b3a8d0b530f7c3598badc6b134562b1a5caee867f7c62fee51e3", size = 266385, upload-time = "2025-06-09T23:01:55.769Z" }, + { url = "https://files.pythonhosted.org/packages/cd/45/e365fdb554159462ca12df54bc59bfa7a9a273ecc21e99e72e597564d1ae/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:3d688126c242a6fabbd92e02633414d40f50bb6002fa4cf995a1d18051525657", size = 288771, upload-time = "2025-06-09T23:01:57.4Z" }, + { url = "https://files.pythonhosted.org/packages/00/11/47b6117002a0e904f004d70ec5194fe9144f117c33c851e3d51c765962d0/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:4e7e9652b3d367c7bd449a727dc79d5043f48b88d0cbfd4f9f1060cf2b414104", size = 288206, upload-time = "2025-06-09T23:01:58.936Z" }, + { url = "https://files.pythonhosted.org/packages/40/37/5f9f3c3fd7f7746082ec67bcdc204db72dad081f4f83a503d33220a92973/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1a85e345b4c43db8b842cab1feb41be5cc0b10a1830e6295b69d7310f99becaf", size = 282620, upload-time = "2025-06-09T23:02:00.493Z" }, + { url = 
"https://files.pythonhosted.org/packages/0b/31/8fbc5af2d183bff20f21aa743b4088eac4445d2bb1cdece449ae80e4e2d1/frozenlist-1.7.0-cp313-cp313t-win32.whl", hash = "sha256:3a14027124ddb70dfcee5148979998066897e79f89f64b13328595c4bdf77c81", size = 43059, upload-time = "2025-06-09T23:02:02.072Z" }, + { url = "https://files.pythonhosted.org/packages/bb/ed/41956f52105b8dbc26e457c5705340c67c8cc2b79f394b79bffc09d0e938/frozenlist-1.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:3bf8010d71d4507775f658e9823210b7427be36625b387221642725b515dcf3e", size = 47516, upload-time = "2025-06-09T23:02:03.779Z" }, + { url = "https://files.pythonhosted.org/packages/ee/45/b82e3c16be2182bff01179db177fe144d58b5dc787a7d4492c6ed8b9317f/frozenlist-1.7.0-py3-none-any.whl", hash = "sha256:9a5af342e34f7e97caf8c995864c7a396418ae2859cc6fdf1b1073020d516a7e", size = 13106, upload-time = "2025-06-09T23:02:34.204Z" }, +] + +[[package]] +name = "fsspec" +version = "2025.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f4/5721faf47b8c499e776bc34c6a8fc17efdf7fdef0b00f398128bc5dcb4ac/fsspec-2025.3.0.tar.gz", hash = "sha256:a935fd1ea872591f2b5148907d103488fc523295e6c64b835cfad8c3eca44972", size = 298491, upload-time = "2025-03-07T21:47:56.461Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615, upload-time = "2025-03-07T21:47:54.809Z" }, +] + +[package.optional-dependencies] +http = [ + { name = "aiohttp" }, +] + +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.45" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { 
url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.1.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/23/0f/5b60fc28ee7f8cc17a5114a584fd6b86e11c3e0a6e142a7f97a161e9640a/hf_xet-1.1.9.tar.gz", hash = "sha256:c99073ce404462e909f1d5839b2d14a3827b8fe75ed8aed551ba6609c026c803", size = 484242, upload-time = "2025-08-27T23:05:19.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/12/56e1abb9a44cdef59a411fe8a8673313195711b5ecce27880eb9c8fa90bd/hf_xet-1.1.9-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a3b6215f88638dd7a6ff82cb4e738dcbf3d863bf667997c093a3c990337d1160", size = 2762553, upload-time = "2025-08-27T23:05:15.153Z" }, + { url = "https://files.pythonhosted.org/packages/3a/e6/2d0d16890c5f21b862f5df3146519c182e7f0ae49b4b4bf2bd8a40d0b05e/hf_xet-1.1.9-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:9b486de7a64a66f9a172f4b3e0dfe79c9f0a93257c501296a2521a13495a698a", size = 2623216, upload-time = "2025-08-27T23:05:13.778Z" }, + { url = "https://files.pythonhosted.org/packages/81/42/7e6955cf0621e87491a1fb8cad755d5c2517803cea174229b0ec00ff0166/hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4c5a840c2c4e6ec875ed13703a60e3523bc7f48031dfd750923b2a4d1a5fc3c", size = 3186789, upload-time = "2025-08-27T23:05:12.368Z" }, + { url = "https://files.pythonhosted.org/packages/df/8b/759233bce05457f5f7ec062d63bbfd2d0c740b816279eaaa54be92aa452a/hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:96a6139c9e44dad1c52c52520db0fffe948f6bce487cfb9d69c125f254bb3790", size = 3088747, upload-time = "2025-08-27T23:05:10.439Z" }, + { url = "https://files.pythonhosted.org/packages/6c/3c/28cc4db153a7601a996985bcb564f7b8f5b9e1a706c7537aad4b4809f358/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ad1022e9a998e784c97b2173965d07fe33ee26e4594770b7785a8cc8f922cd95", size = 3251429, upload-time = "2025-08-27T23:05:16.471Z" }, + { url = "https://files.pythonhosted.org/packages/84/17/7caf27a1d101bfcb05be85850d4aa0a265b2e1acc2d4d52a48026ef1d299/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:86754c2d6d5afb11b0a435e6e18911a4199262fe77553f8c50d75e21242193ea", size = 3354643, upload-time = "2025-08-27T23:05:17.828Z" }, + { url = "https://files.pythonhosted.org/packages/cd/50/0c39c9eed3411deadcc98749a6699d871b822473f55fe472fad7c01ec588/hf_xet-1.1.9-cp37-abi3-win_amd64.whl", hash = "sha256:5aad3933de6b725d61d51034e04174ed1dce7a57c63d530df0014dea15a40127", size = 2804797, upload-time = "2025-08-27T23:05:20.77Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "0.34.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = 
"sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768, upload-time = "2025-08-08T09:14:52.365Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537, upload-time = "2024-10-18T15:21:54.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357, upload-time = "2024-10-18T15:20:51.44Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", 
size = 12393, upload-time = "2024-10-18T15:20:52.426Z" }, + { url = "https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732, upload-time = "2024-10-18T15:20:53.578Z" }, + { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866, upload-time = "2024-10-18T15:20:55.06Z" }, + { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964, upload-time = "2024-10-18T15:20:55.906Z" }, + { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977, upload-time = "2024-10-18T15:20:57.189Z" }, + { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366, upload-time = "2024-10-18T15:20:58.235Z" }, + { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091, upload-time = "2024-10-18T15:20:59.235Z" }, + { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065, upload-time = "2024-10-18T15:21:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514, upload-time = "2024-10-18T15:21:01.122Z" }, + { url = "https://files.pythonhosted.org/packages/6b/28/bbf83e3f76936960b850435576dd5e67034e200469571be53f69174a2dfd/MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d", size = 14353, upload-time = "2024-10-18T15:21:02.187Z" }, + { url = "https://files.pythonhosted.org/packages/6c/30/316d194b093cde57d448a4c3209f22e3046c5bb2fb0820b118292b334be7/MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93", size = 12392, upload-time = "2024-10-18T15:21:02.941Z" }, + { url = "https://files.pythonhosted.org/packages/f2/96/9cdafba8445d3a53cae530aaf83c38ec64c4d5427d975c974084af5bc5d2/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832", size = 
23984, upload-time = "2024-10-18T15:21:03.953Z" }, + { url = "https://files.pythonhosted.org/packages/f1/a4/aefb044a2cd8d7334c8a47d3fb2c9f328ac48cb349468cc31c20b539305f/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84", size = 23120, upload-time = "2024-10-18T15:21:06.495Z" }, + { url = "https://files.pythonhosted.org/packages/8d/21/5e4851379f88f3fad1de30361db501300d4f07bcad047d3cb0449fc51f8c/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca", size = 23032, upload-time = "2024-10-18T15:21:07.295Z" }, + { url = "https://files.pythonhosted.org/packages/00/7b/e92c64e079b2d0d7ddf69899c98842f3f9a60a1ae72657c89ce2655c999d/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798", size = 24057, upload-time = "2024-10-18T15:21:08.073Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ac/46f960ca323037caa0a10662ef97d0a4728e890334fc156b9f9e52bcc4ca/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e", size = 23359, upload-time = "2024-10-18T15:21:09.318Z" }, + { url = "https://files.pythonhosted.org/packages/69/84/83439e16197337b8b14b6a5b9c2105fff81d42c2a7c5b58ac7b62ee2c3b1/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4", size = 23306, upload-time = "2024-10-18T15:21:10.185Z" }, + { url = "https://files.pythonhosted.org/packages/9a/34/a15aa69f01e2181ed8d2b685c0d2f6655d5cca2c4db0ddea775e631918cd/MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d", size = 15094, upload-time = "2024-10-18T15:21:11.005Z" }, + { url = "https://files.pythonhosted.org/packages/da/b8/3a3bd761922d416f3dc5d00bfbed11f66b1ab89a0c2b6e887240a30b0f6b/MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b", size = 15521, upload-time = "2024-10-18T15:21:12.911Z" }, + { url = "https://files.pythonhosted.org/packages/22/09/d1f21434c97fc42f09d290cbb6350d44eb12f09cc62c9476effdb33a18aa/MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", size = 14274, upload-time = "2024-10-18T15:21:13.777Z" }, + { url = "https://files.pythonhosted.org/packages/6b/b0/18f76bba336fa5aecf79d45dcd6c806c280ec44538b3c13671d49099fdd0/MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", size = 12348, upload-time = "2024-10-18T15:21:14.822Z" }, + { url = "https://files.pythonhosted.org/packages/e0/25/dd5c0f6ac1311e9b40f4af06c78efde0f3b5cbf02502f8ef9501294c425b/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", size = 24149, upload-time = "2024-10-18T15:21:15.642Z" }, + { url = "https://files.pythonhosted.org/packages/f3/f0/89e7aadfb3749d0f52234a0c8c7867877876e0a20b60e2188e9850794c17/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", size = 23118, 
upload-time = "2024-10-18T15:21:17.133Z" }, + { url = "https://files.pythonhosted.org/packages/d5/da/f2eeb64c723f5e3777bc081da884b414671982008c47dcc1873d81f625b6/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", size = 22993, upload-time = "2024-10-18T15:21:18.064Z" }, + { url = "https://files.pythonhosted.org/packages/da/0e/1f32af846df486dce7c227fe0f2398dc7e2e51d4a370508281f3c1c5cddc/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", size = 24178, upload-time = "2024-10-18T15:21:18.859Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f6/bb3ca0532de8086cbff5f06d137064c8410d10779c4c127e0e47d17c0b71/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", size = 23319, upload-time = "2024-10-18T15:21:19.671Z" }, + { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352, upload-time = "2024-10-18T15:21:20.971Z" }, + { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097, upload-time = "2024-10-18T15:21:22.646Z" }, + { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601, upload-time = "2024-10-18T15:21:23.499Z" }, + { url = "https://files.pythonhosted.org/packages/83/0e/67eb10a7ecc77a0c2bbe2b0235765b98d164d81600746914bebada795e97/MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd", size = 14274, upload-time = "2024-10-18T15:21:24.577Z" }, + { url = "https://files.pythonhosted.org/packages/2b/6d/9409f3684d3335375d04e5f05744dfe7e9f120062c9857df4ab490a1031a/MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430", size = 12352, upload-time = "2024-10-18T15:21:25.382Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f5/6eadfcd3885ea85fe2a7c128315cc1bb7241e1987443d78c8fe712d03091/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094", size = 24122, upload-time = "2024-10-18T15:21:26.199Z" }, + { url = "https://files.pythonhosted.org/packages/0c/91/96cf928db8236f1bfab6ce15ad070dfdd02ed88261c2afafd4b43575e9e9/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396", size = 23085, upload-time = "2024-10-18T15:21:27.029Z" }, + { url = "https://files.pythonhosted.org/packages/c2/cf/c9d56af24d56ea04daae7ac0940232d31d5a8354f2b457c6d856b2057d69/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79", size = 22978, upload-time = "2024-10-18T15:21:27.846Z" }, + { url = "https://files.pythonhosted.org/packages/2a/9f/8619835cd6a711d6272d62abb78c033bda638fdc54c4e7f4272cf1c0962b/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a", size = 24208, upload-time = "2024-10-18T15:21:28.744Z" }, + { url = "https://files.pythonhosted.org/packages/f9/bf/176950a1792b2cd2102b8ffeb5133e1ed984547b75db47c25a67d3359f77/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca", size = 23357, upload-time = "2024-10-18T15:21:29.545Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4f/9a02c1d335caabe5c4efb90e1b6e8ee944aa245c1aaaab8e8a618987d816/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c", size = 23344, upload-time = "2024-10-18T15:21:30.366Z" }, + { url = "https://files.pythonhosted.org/packages/ee/55/c271b57db36f748f0e04a759ace9f8f759ccf22b4960c270c78a394f58be/MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1", size = 15101, upload-time = "2024-10-18T15:21:31.207Z" }, + { url = "https://files.pythonhosted.org/packages/29/88/07df22d2dd4df40aba9f3e402e6dc1b8ee86297dddbad4872bd5e7b0094f/MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f", size = 15603, upload-time = "2024-10-18T15:21:32.032Z" }, + { url = "https://files.pythonhosted.org/packages/62/6a/8b89d24db2d32d433dffcd6a8779159da109842434f1dd2f6e71f32f738c/MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c", size = 14510, upload-time = "2024-10-18T15:21:33.625Z" }, + { url = "https://files.pythonhosted.org/packages/7a/06/a10f955f70a2e5a9bf78d11a161029d278eeacbd35ef806c3fd17b13060d/MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb", size = 12486, upload-time = "2024-10-18T15:21:34.611Z" }, + { url = "https://files.pythonhosted.org/packages/34/cf/65d4a571869a1a9078198ca28f39fba5fbb910f952f9dbc5220afff9f5e6/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c", size = 25480, upload-time = "2024-10-18T15:21:35.398Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e3/90e9651924c430b885468b56b3d597cabf6d72be4b24a0acd1fa0e12af67/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d", size = 23914, upload-time = "2024-10-18T15:21:36.231Z" }, + { url = "https://files.pythonhosted.org/packages/66/8c/6c7cf61f95d63bb866db39085150df1f2a5bd3335298f14a66b48e92659c/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe", size = 23796, upload-time = "2024-10-18T15:21:37.073Z" }, + { url = "https://files.pythonhosted.org/packages/bb/35/cbe9238ec3f47ac9a7c8b3df7a808e7cb50fe149dc7039f5f454b3fba218/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = 
"sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5", size = 25473, upload-time = "2024-10-18T15:21:37.932Z" }, + { url = "https://files.pythonhosted.org/packages/e6/32/7621a4382488aa283cc05e8984a9c219abad3bca087be9ec77e89939ded9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a", size = 24114, upload-time = "2024-10-18T15:21:39.799Z" }, + { url = "https://files.pythonhosted.org/packages/0d/80/0985960e4b89922cb5a0bac0ed39c5b96cbc1a536a99f30e8c220a996ed9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", size = 24098, upload-time = "2024-10-18T15:21:40.813Z" }, + { url = "https://files.pythonhosted.org/packages/82/78/fedb03c7d5380df2427038ec8d973587e90561b2d90cd472ce9254cf348b/MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6", size = 15208, upload-time = "2024-10-18T15:21:41.814Z" }, + { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, +] + +[[package]] +name = "maturin" +version = "1.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7c/b11b870fc4fd84de2099906314ce45488ae17be32ff5493519a6cddc518a/maturin-1.9.4.tar.gz", hash = "sha256:235163a0c99bc6f380fb8786c04fd14dcf6cd622ff295ea3de525015e6ac40cf", size = 213647, upload-time = "2025-08-27T11:37:57.079Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/90/0d99389eea1939116fca841cad0763600c8d3183a02a9478d066736c60e8/maturin-1.9.4-py3-none-linux_armv6l.whl", hash = "sha256:6ff37578e3f5fdbe685110d45f60af1f5a7dfce70a1e26dfe3810af66853ecae", size = 8276133, upload-time = "2025-08-27T11:37:23.325Z" }, + { url = "https://files.pythonhosted.org/packages/f4/ed/c8ec68b383e50f084bf1fa9605e62a90cd32a3f75d9894ed3a6e5d4cc5b3/maturin-1.9.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:f3837bb53611b2dafa1c090436c330f2d743ba305ef00d8801a371f4495e7e1b", size = 15994496, upload-time = "2025-08-27T11:37:27.092Z" }, + { url = "https://files.pythonhosted.org/packages/84/4e/401ff5f3cfc6b123364d4b94379bf910d7baee32c9c95b72784ff2329357/maturin-1.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4227d627d8e3bfe45877a8d65e9d8351a9d01434549f0da75d2c06a1b570de58", size = 8362228, upload-time = "2025-08-27T11:37:31.181Z" }, + { url = "https://files.pythonhosted.org/packages/51/8e/c56176dd360da9650c62b8a5ecfb85432cf011e97e46c186901e6996002e/maturin-1.9.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:1bb2aa0fa29032e9c5aac03ac400396ddea12cadef242f8967e9c8ef715313a1", size = 8271397, upload-time = "2025-08-27T11:37:33.672Z" }, + { url = "https://files.pythonhosted.org/packages/d2/46/001fcc5c6ad509874896418d6169a61acd619df5b724f99766308c44a99f/maturin-1.9.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:a0868d52934c8a5d1411b42367633fdb5cd5515bec47a534192282167448ec30", size = 8775625, upload-time = "2025-08-27T11:37:35.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/2e/26fa7574f01c19b7a74680fd70e5bae2e8c40fed9683d1752e765062cc2b/maturin-1.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:68b7b833b25741c0f553b78e8b9e095b31ae7c6611533b3c7b71f84c2cb8fc44", size = 8051117, upload-time = "2025-08-27T11:37:38.278Z" }, + { url = "https://files.pythonhosted.org/packages/73/ee/ca7308832d4f5b521c1aa176d9265f6f93e0bd1ad82a90fd9cd799f6b28c/maturin-1.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:08dc86312afee55af778af919818632e35d8d0464ccd79cb86700d9ea560ccd7", size = 8132122, upload-time = "2025-08-27T11:37:40.499Z" }, + { url = "https://files.pythonhosted.org/packages/45/e8/c623955da75e801a06942edf1fdc4e772a9e8fbc1ceebbdc85d59584dc10/maturin-1.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:ef20ffdd943078c4c3699c29fb2ed722bb6b4419efdade6642d1dbf248f94a70", size = 10586762, upload-time = "2025-08-27T11:37:42.718Z" }, + { url = "https://files.pythonhosted.org/packages/3c/4b/19ad558fdf54e151b1b4916ed45f1952ada96684ee6db64f9cd91cabec09/maturin-1.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:368e958468431dfeec80f75eea9639b4356d8c42428b0128444424b083fecfb0", size = 8926988, upload-time = "2025-08-27T11:37:45.492Z" }, + { url = "https://files.pythonhosted.org/packages/7e/27/153ad15eccae26921e8a01812da9f3b7f9013368f8f92c36853f2043b2a3/maturin-1.9.4-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:273f879214f63f79bfe851cd7d541f8150bdbfae5dfdc3c0c4d125d02d1f41b4", size = 8536758, upload-time = "2025-08-27T11:37:48.213Z" }, + { url = "https://files.pythonhosted.org/packages/43/e3/f304c3bdc3fba9adebe5348d4d2dd015f1152c0a9027aaf52cae0bb182c8/maturin-1.9.4-py3-none-win32.whl", hash = "sha256:ed2e54d132ace7e61829bd49709331007dd9a2cc78937f598aa76a4f69b6804d", size = 7265200, upload-time = "2025-08-27T11:37:50.881Z" }, + { url = "https://files.pythonhosted.org/packages/14/14/f86d0124bf1816b99005c058a1dbdca7cb5850d9cf4b09dcae07a1bc6201/maturin-1.9.4-py3-none-win_amd64.whl", hash = "sha256:8e450bb2c9afdf38a0059ee2e1ec2b17323f152b59c16f33eb9c74edaf1f9f79", size = 8237391, upload-time = "2025-08-27T11:37:53.23Z" }, + { url = "https://files.pythonhosted.org/packages/3f/25/8320fc2591e45b750c3ae71fa596b47aefa802d07d6abaaa719034a85160/maturin-1.9.4-py3-none-win_arm64.whl", hash = "sha256:7a6f980a9b67a5c13c844c268eabd855b54a6a765df4b4bb07d15a990572a4c9", size = 6988277, upload-time = "2025-08-27T11:37:55.429Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "multidict" +version = "6.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/69/7f/0652e6ed47ab288e3756ea9c0df8b14950781184d4bd7883f4d87dd41245/multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd", size = 101843, upload-time = "2025-08-11T12:08:48.217Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/6b/86f353088c1358e76fd30b0146947fddecee812703b604ee901e85cd2a80/multidict-6.6.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b8aa6f0bd8125ddd04a6593437bad6a7e70f300ff4180a531654aa2ab3f6d58f", size = 77054, upload-time = "2025-08-11T12:06:02.99Z" }, + { url = "https://files.pythonhosted.org/packages/19/5d/c01dc3d3788bb877bd7f5753ea6eb23c1beeca8044902a8f5bfb54430f63/multidict-6.6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9e5853bbd7264baca42ffc53391b490d65fe62849bf2c690fa3f6273dbcd0cb", size = 44914, upload-time = "2025-08-11T12:06:05.264Z" }, + { url = "https://files.pythonhosted.org/packages/46/44/964dae19ea42f7d3e166474d8205f14bb811020e28bc423d46123ddda763/multidict-6.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0af5f9dee472371e36d6ae38bde009bd8ce65ac7335f55dcc240379d7bed1495", size = 44601, upload-time = "2025-08-11T12:06:06.627Z" }, + { url = "https://files.pythonhosted.org/packages/31/20/0616348a1dfb36cb2ab33fc9521de1f27235a397bf3f59338e583afadd17/multidict-6.6.4-cp310-cp310-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:d24f351e4d759f5054b641c81e8291e5d122af0fca5c72454ff77f7cbe492de8", size = 224821, upload-time = "2025-08-11T12:06:08.06Z" }, + { url = "https://files.pythonhosted.org/packages/14/26/5d8923c69c110ff51861af05bd27ca6783011b96725d59ccae6d9daeb627/multidict-6.6.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db6a3810eec08280a172a6cd541ff4a5f6a97b161d93ec94e6c4018917deb6b7", size = 242608, upload-time = "2025-08-11T12:06:09.697Z" }, + { url = "https://files.pythonhosted.org/packages/5c/cc/e2ad3ba9459aa34fa65cf1f82a5c4a820a2ce615aacfb5143b8817f76504/multidict-6.6.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a1b20a9d56b2d81e2ff52ecc0670d583eaabaa55f402e8d16dd062373dbbe796", size = 222324, upload-time = "2025-08-11T12:06:10.905Z" }, + { url = "https://files.pythonhosted.org/packages/19/db/4ed0f65701afbc2cb0c140d2d02928bb0fe38dd044af76e58ad7c54fd21f/multidict-6.6.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8c9854df0eaa610a23494c32a6f44a3a550fb398b6b51a56e8c6b9b3689578db", size = 253234, upload-time = "2025-08-11T12:06:12.658Z" }, + { url = "https://files.pythonhosted.org/packages/94/c1/5160c9813269e39ae14b73debb907bfaaa1beee1762da8c4fb95df4764ed/multidict-6.6.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4bb7627fd7a968f41905a4d6343b0d63244a0623f006e9ed989fa2b78f4438a0", size = 251613, upload-time = "2025-08-11T12:06:13.97Z" }, + { url = "https://files.pythonhosted.org/packages/05/a9/48d1bd111fc2f8fb98b2ed7f9a115c55a9355358432a19f53c0b74d8425d/multidict-6.6.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:caebafea30ed049c57c673d0b36238b1748683be2593965614d7b0e99125c877", size = 241649, upload-time = "2025-08-11T12:06:15.204Z" }, + { url = "https://files.pythonhosted.org/packages/85/2a/f7d743df0019408768af8a70d2037546a2be7b81fbb65f040d76caafd4c5/multidict-6.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", 
hash = "sha256:ad887a8250eb47d3ab083d2f98db7f48098d13d42eb7a3b67d8a5c795f224ace", size = 239238, upload-time = "2025-08-11T12:06:16.467Z" }, + { url = "https://files.pythonhosted.org/packages/cb/b8/4f4bb13323c2d647323f7919201493cf48ebe7ded971717bfb0f1a79b6bf/multidict-6.6.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:ed8358ae7d94ffb7c397cecb62cbac9578a83ecefc1eba27b9090ee910e2efb6", size = 233517, upload-time = "2025-08-11T12:06:18.107Z" }, + { url = "https://files.pythonhosted.org/packages/33/29/4293c26029ebfbba4f574febd2ed01b6f619cfa0d2e344217d53eef34192/multidict-6.6.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ecab51ad2462197a4c000b6d5701fc8585b80eecb90583635d7e327b7b6923eb", size = 243122, upload-time = "2025-08-11T12:06:19.361Z" }, + { url = "https://files.pythonhosted.org/packages/20/60/a1c53628168aa22447bfde3a8730096ac28086704a0d8c590f3b63388d0c/multidict-6.6.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c5c97aa666cf70e667dfa5af945424ba1329af5dd988a437efeb3a09430389fb", size = 248992, upload-time = "2025-08-11T12:06:20.661Z" }, + { url = "https://files.pythonhosted.org/packages/a3/3b/55443a0c372f33cae5d9ec37a6a973802884fa0ab3586659b197cf8cc5e9/multidict-6.6.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:9a950b7cf54099c1209f455ac5970b1ea81410f2af60ed9eb3c3f14f0bfcf987", size = 243708, upload-time = "2025-08-11T12:06:21.891Z" }, + { url = "https://files.pythonhosted.org/packages/7c/60/a18c6900086769312560b2626b18e8cca22d9e85b1186ba77f4755b11266/multidict-6.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:163c7ea522ea9365a8a57832dea7618e6cbdc3cd75f8c627663587459a4e328f", size = 237498, upload-time = "2025-08-11T12:06:23.206Z" }, + { url = "https://files.pythonhosted.org/packages/11/3d/8bdd8bcaff2951ce2affccca107a404925a2beafedd5aef0b5e4a71120a6/multidict-6.6.4-cp310-cp310-win32.whl", hash = "sha256:17d2cbbfa6ff20821396b25890f155f40c986f9cfbce5667759696d83504954f", size = 41415, upload-time = "2025-08-11T12:06:24.77Z" }, + { url = "https://files.pythonhosted.org/packages/c0/53/cab1ad80356a4cd1b685a254b680167059b433b573e53872fab245e9fc95/multidict-6.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:ce9a40fbe52e57e7edf20113a4eaddfacac0561a0879734e636aa6d4bb5e3fb0", size = 46046, upload-time = "2025-08-11T12:06:25.893Z" }, + { url = "https://files.pythonhosted.org/packages/cf/9a/874212b6f5c1c2d870d0a7adc5bb4cfe9b0624fa15cdf5cf757c0f5087ae/multidict-6.6.4-cp310-cp310-win_arm64.whl", hash = "sha256:01d0959807a451fe9fdd4da3e139cb5b77f7328baf2140feeaf233e1d777b729", size = 43147, upload-time = "2025-08-11T12:06:27.534Z" }, + { url = "https://files.pythonhosted.org/packages/6b/7f/90a7f01e2d005d6653c689039977f6856718c75c5579445effb7e60923d1/multidict-6.6.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c7a0e9b561e6460484318a7612e725df1145d46b0ef57c6b9866441bf6e27e0c", size = 76472, upload-time = "2025-08-11T12:06:29.006Z" }, + { url = "https://files.pythonhosted.org/packages/54/a3/bed07bc9e2bb302ce752f1dabc69e884cd6a676da44fb0e501b246031fdd/multidict-6.6.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6bf2f10f70acc7a2446965ffbc726e5fc0b272c97a90b485857e5c70022213eb", size = 44634, upload-time = "2025-08-11T12:06:30.374Z" }, + { url = "https://files.pythonhosted.org/packages/a7/4b/ceeb4f8f33cf81277da464307afeaf164fb0297947642585884f5cad4f28/multidict-6.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66247d72ed62d5dd29752ffc1d3b88f135c6a8de8b5f63b7c14e973ef5bda19e", size = 44282, upload-time = "2025-08-11T12:06:31.958Z" }, + { url 
= "https://files.pythonhosted.org/packages/03/35/436a5da8702b06866189b69f655ffdb8f70796252a8772a77815f1812679/multidict-6.6.4-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:105245cc6b76f51e408451a844a54e6823bbd5a490ebfe5bdfc79798511ceded", size = 229696, upload-time = "2025-08-11T12:06:33.087Z" }, + { url = "https://files.pythonhosted.org/packages/b6/0e/915160be8fecf1fca35f790c08fb74ca684d752fcba62c11daaf3d92c216/multidict-6.6.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cbbc54e58b34c3bae389ef00046be0961f30fef7cb0dd9c7756aee376a4f7683", size = 246665, upload-time = "2025-08-11T12:06:34.448Z" }, + { url = "https://files.pythonhosted.org/packages/08/ee/2f464330acd83f77dcc346f0b1a0eaae10230291450887f96b204b8ac4d3/multidict-6.6.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:56c6b3652f945c9bc3ac6c8178cd93132b8d82dd581fcbc3a00676c51302bc1a", size = 225485, upload-time = "2025-08-11T12:06:35.672Z" }, + { url = "https://files.pythonhosted.org/packages/71/cc/9a117f828b4d7fbaec6adeed2204f211e9caf0a012692a1ee32169f846ae/multidict-6.6.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b95494daf857602eccf4c18ca33337dd2be705bccdb6dddbfc9d513e6addb9d9", size = 257318, upload-time = "2025-08-11T12:06:36.98Z" }, + { url = "https://files.pythonhosted.org/packages/25/77/62752d3dbd70e27fdd68e86626c1ae6bccfebe2bb1f84ae226363e112f5a/multidict-6.6.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e5b1413361cef15340ab9dc61523e653d25723e82d488ef7d60a12878227ed50", size = 254689, upload-time = "2025-08-11T12:06:38.233Z" }, + { url = "https://files.pythonhosted.org/packages/00/6e/fac58b1072a6fc59af5e7acb245e8754d3e1f97f4f808a6559951f72a0d4/multidict-6.6.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e167bf899c3d724f9662ef00b4f7fef87a19c22b2fead198a6f68b263618df52", size = 246709, upload-time = "2025-08-11T12:06:39.517Z" }, + { url = "https://files.pythonhosted.org/packages/01/ef/4698d6842ef5e797c6db7744b0081e36fb5de3d00002cc4c58071097fac3/multidict-6.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aaea28ba20a9026dfa77f4b80369e51cb767c61e33a2d4043399c67bd95fb7c6", size = 243185, upload-time = "2025-08-11T12:06:40.796Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c9/d82e95ae1d6e4ef396934e9b0e942dfc428775f9554acf04393cce66b157/multidict-6.6.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8c91cdb30809a96d9ecf442ec9bc45e8cfaa0f7f8bdf534e082c2443a196727e", size = 237838, upload-time = "2025-08-11T12:06:42.595Z" }, + { url = "https://files.pythonhosted.org/packages/57/cf/f94af5c36baaa75d44fab9f02e2a6bcfa0cd90acb44d4976a80960759dbc/multidict-6.6.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a0ccbfe93ca114c5d65a2471d52d8829e56d467c97b0e341cf5ee45410033b3", size = 246368, upload-time = "2025-08-11T12:06:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/4a/fe/29f23460c3d995f6a4b678cb2e9730e7277231b981f0b234702f0177818a/multidict-6.6.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:55624b3f321d84c403cb7d8e6e982f41ae233d85f85db54ba6286f7295dc8a9c", size = 253339, upload-time = "2025-08-11T12:06:45.597Z" }, + { url = 
"https://files.pythonhosted.org/packages/29/b6/fd59449204426187b82bf8a75f629310f68c6adc9559dc922d5abe34797b/multidict-6.6.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:4a1fb393a2c9d202cb766c76208bd7945bc194eba8ac920ce98c6e458f0b524b", size = 246933, upload-time = "2025-08-11T12:06:46.841Z" }, + { url = "https://files.pythonhosted.org/packages/19/52/d5d6b344f176a5ac3606f7a61fb44dc746e04550e1a13834dff722b8d7d6/multidict-6.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43868297a5759a845fa3a483fb4392973a95fb1de891605a3728130c52b8f40f", size = 242225, upload-time = "2025-08-11T12:06:48.588Z" }, + { url = "https://files.pythonhosted.org/packages/ec/d3/5b2281ed89ff4d5318d82478a2a2450fcdfc3300da48ff15c1778280ad26/multidict-6.6.4-cp311-cp311-win32.whl", hash = "sha256:ed3b94c5e362a8a84d69642dbeac615452e8af9b8eb825b7bc9f31a53a1051e2", size = 41306, upload-time = "2025-08-11T12:06:49.95Z" }, + { url = "https://files.pythonhosted.org/packages/74/7d/36b045c23a1ab98507aefd44fd8b264ee1dd5e5010543c6fccf82141ccef/multidict-6.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:d8c112f7a90d8ca5d20213aa41eac690bb50a76da153e3afb3886418e61cb22e", size = 46029, upload-time = "2025-08-11T12:06:51.082Z" }, + { url = "https://files.pythonhosted.org/packages/0f/5e/553d67d24432c5cd52b49047f2d248821843743ee6d29a704594f656d182/multidict-6.6.4-cp311-cp311-win_arm64.whl", hash = "sha256:3bb0eae408fa1996d87247ca0d6a57b7fc1dcf83e8a5c47ab82c558c250d4adf", size = 43017, upload-time = "2025-08-11T12:06:52.243Z" }, + { url = "https://files.pythonhosted.org/packages/05/f6/512ffd8fd8b37fb2680e5ac35d788f1d71bbaf37789d21a820bdc441e565/multidict-6.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0ffb87be160942d56d7b87b0fdf098e81ed565add09eaa1294268c7f3caac4c8", size = 76516, upload-time = "2025-08-11T12:06:53.393Z" }, + { url = "https://files.pythonhosted.org/packages/99/58/45c3e75deb8855c36bd66cc1658007589662ba584dbf423d01df478dd1c5/multidict-6.6.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d191de6cbab2aff5de6c5723101705fd044b3e4c7cfd587a1929b5028b9714b3", size = 45394, upload-time = "2025-08-11T12:06:54.555Z" }, + { url = "https://files.pythonhosted.org/packages/fd/ca/e8c4472a93a26e4507c0b8e1f0762c0d8a32de1328ef72fd704ef9cc5447/multidict-6.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38a0956dd92d918ad5feff3db8fcb4a5eb7dba114da917e1a88475619781b57b", size = 43591, upload-time = "2025-08-11T12:06:55.672Z" }, + { url = "https://files.pythonhosted.org/packages/05/51/edf414f4df058574a7265034d04c935aa84a89e79ce90fcf4df211f47b16/multidict-6.6.4-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:6865f6d3b7900ae020b495d599fcf3765653bc927951c1abb959017f81ae8287", size = 237215, upload-time = "2025-08-11T12:06:57.213Z" }, + { url = "https://files.pythonhosted.org/packages/c8/45/8b3d6dbad8cf3252553cc41abea09ad527b33ce47a5e199072620b296902/multidict-6.6.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a2088c126b6f72db6c9212ad827d0ba088c01d951cee25e758c450da732c138", size = 258299, upload-time = "2025-08-11T12:06:58.946Z" }, + { url = "https://files.pythonhosted.org/packages/3c/e8/8ca2e9a9f5a435fc6db40438a55730a4bf4956b554e487fa1b9ae920f825/multidict-6.6.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0f37bed7319b848097085d7d48116f545985db988e2256b2e6f00563a3416ee6", size = 242357, upload-time = "2025-08-11T12:07:00.301Z" }, + { url = 
"https://files.pythonhosted.org/packages/0f/84/80c77c99df05a75c28490b2af8f7cba2a12621186e0a8b0865d8e745c104/multidict-6.6.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:01368e3c94032ba6ca0b78e7ccb099643466cf24f8dc8eefcfdc0571d56e58f9", size = 268369, upload-time = "2025-08-11T12:07:01.638Z" }, + { url = "https://files.pythonhosted.org/packages/0d/e9/920bfa46c27b05fb3e1ad85121fd49f441492dca2449c5bcfe42e4565d8a/multidict-6.6.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8fe323540c255db0bffee79ad7f048c909f2ab0edb87a597e1c17da6a54e493c", size = 269341, upload-time = "2025-08-11T12:07:02.943Z" }, + { url = "https://files.pythonhosted.org/packages/af/65/753a2d8b05daf496f4a9c367fe844e90a1b2cac78e2be2c844200d10cc4c/multidict-6.6.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8eb3025f17b0a4c3cd08cda49acf312a19ad6e8a4edd9dbd591e6506d999402", size = 256100, upload-time = "2025-08-11T12:07:04.564Z" }, + { url = "https://files.pythonhosted.org/packages/09/54/655be13ae324212bf0bc15d665a4e34844f34c206f78801be42f7a0a8aaa/multidict-6.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bbc14f0365534d35a06970d6a83478b249752e922d662dc24d489af1aa0d1be7", size = 253584, upload-time = "2025-08-11T12:07:05.914Z" }, + { url = "https://files.pythonhosted.org/packages/5c/74/ab2039ecc05264b5cec73eb018ce417af3ebb384ae9c0e9ed42cb33f8151/multidict-6.6.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:75aa52fba2d96bf972e85451b99d8e19cc37ce26fd016f6d4aa60da9ab2b005f", size = 251018, upload-time = "2025-08-11T12:07:08.301Z" }, + { url = "https://files.pythonhosted.org/packages/af/0a/ccbb244ac848e56c6427f2392741c06302bbfba49c0042f1eb3c5b606497/multidict-6.6.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fefd4a815e362d4f011919d97d7b4a1e566f1dde83dc4ad8cfb5b41de1df68d", size = 251477, upload-time = "2025-08-11T12:07:10.248Z" }, + { url = "https://files.pythonhosted.org/packages/0e/b0/0ed49bba775b135937f52fe13922bc64a7eaf0a3ead84a36e8e4e446e096/multidict-6.6.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:db9801fe021f59a5b375ab778973127ca0ac52429a26e2fd86aa9508f4d26eb7", size = 263575, upload-time = "2025-08-11T12:07:11.928Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d9/7fb85a85e14de2e44dfb6a24f03c41e2af8697a6df83daddb0e9b7569f73/multidict-6.6.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a650629970fa21ac1fb06ba25dabfc5b8a2054fcbf6ae97c758aa956b8dba802", size = 259649, upload-time = "2025-08-11T12:07:13.244Z" }, + { url = "https://files.pythonhosted.org/packages/03/9e/b3a459bcf9b6e74fa461a5222a10ff9b544cb1cd52fd482fb1b75ecda2a2/multidict-6.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:452ff5da78d4720d7516a3a2abd804957532dd69296cb77319c193e3ffb87e24", size = 251505, upload-time = "2025-08-11T12:07:14.57Z" }, + { url = "https://files.pythonhosted.org/packages/86/a2/8022f78f041dfe6d71e364001a5cf987c30edfc83c8a5fb7a3f0974cff39/multidict-6.6.4-cp312-cp312-win32.whl", hash = "sha256:8c2fcb12136530ed19572bbba61b407f655e3953ba669b96a35036a11a485793", size = 41888, upload-time = "2025-08-11T12:07:15.904Z" }, + { url = "https://files.pythonhosted.org/packages/c7/eb/d88b1780d43a56db2cba24289fa744a9d216c1a8546a0dc3956563fd53ea/multidict-6.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:047d9425860a8c9544fed1b9584f0c8bcd31bcde9568b047c5e567a1025ecd6e", size = 46072, upload-time = "2025-08-11T12:07:17.045Z" }, + { url 
= "https://files.pythonhosted.org/packages/9f/16/b929320bf5750e2d9d4931835a4c638a19d2494a5b519caaaa7492ebe105/multidict-6.6.4-cp312-cp312-win_arm64.whl", hash = "sha256:14754eb72feaa1e8ae528468f24250dd997b8e2188c3d2f593f9eba259e4b364", size = 43222, upload-time = "2025-08-11T12:07:18.328Z" }, + { url = "https://files.pythonhosted.org/packages/3a/5d/e1db626f64f60008320aab00fbe4f23fc3300d75892a3381275b3d284580/multidict-6.6.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f46a6e8597f9bd71b31cc708195d42b634c8527fecbcf93febf1052cacc1f16e", size = 75848, upload-time = "2025-08-11T12:07:19.912Z" }, + { url = "https://files.pythonhosted.org/packages/4c/aa/8b6f548d839b6c13887253af4e29c939af22a18591bfb5d0ee6f1931dae8/multidict-6.6.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:22e38b2bc176c5eb9c0a0e379f9d188ae4cd8b28c0f53b52bce7ab0a9e534657", size = 45060, upload-time = "2025-08-11T12:07:21.163Z" }, + { url = "https://files.pythonhosted.org/packages/eb/c6/f5e97e5d99a729bc2aa58eb3ebfa9f1e56a9b517cc38c60537c81834a73f/multidict-6.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5df8afd26f162da59e218ac0eefaa01b01b2e6cd606cffa46608f699539246da", size = 43269, upload-time = "2025-08-11T12:07:22.392Z" }, + { url = "https://files.pythonhosted.org/packages/dc/31/d54eb0c62516776f36fe67f84a732f97e0b0e12f98d5685bebcc6d396910/multidict-6.6.4-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:49517449b58d043023720aa58e62b2f74ce9b28f740a0b5d33971149553d72aa", size = 237158, upload-time = "2025-08-11T12:07:23.636Z" }, + { url = "https://files.pythonhosted.org/packages/c4/1c/8a10c1c25b23156e63b12165a929d8eb49a6ed769fdbefb06e6f07c1e50d/multidict-6.6.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae9408439537c5afdca05edd128a63f56a62680f4b3c234301055d7a2000220f", size = 257076, upload-time = "2025-08-11T12:07:25.049Z" }, + { url = "https://files.pythonhosted.org/packages/ad/86/90e20b5771d6805a119e483fd3d1e8393e745a11511aebca41f0da38c3e2/multidict-6.6.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:87a32d20759dc52a9e850fe1061b6e41ab28e2998d44168a8a341b99ded1dba0", size = 240694, upload-time = "2025-08-11T12:07:26.458Z" }, + { url = "https://files.pythonhosted.org/packages/e7/49/484d3e6b535bc0555b52a0a26ba86e4d8d03fd5587d4936dc59ba7583221/multidict-6.6.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:52e3c8d43cdfff587ceedce9deb25e6ae77daba560b626e97a56ddcad3756879", size = 266350, upload-time = "2025-08-11T12:07:27.94Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b4/aa4c5c379b11895083d50021e229e90c408d7d875471cb3abf721e4670d6/multidict-6.6.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ad8850921d3a8d8ff6fbef790e773cecfc260bbfa0566998980d3fa8f520bc4a", size = 267250, upload-time = "2025-08-11T12:07:29.303Z" }, + { url = "https://files.pythonhosted.org/packages/80/e5/5e22c5bf96a64bdd43518b1834c6d95a4922cc2066b7d8e467dae9b6cee6/multidict-6.6.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:497a2954adc25c08daff36f795077f63ad33e13f19bfff7736e72c785391534f", size = 254900, upload-time = "2025-08-11T12:07:30.764Z" }, + { url = 
"https://files.pythonhosted.org/packages/17/38/58b27fed927c07035abc02befacab42491e7388ca105e087e6e0215ead64/multidict-6.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:024ce601f92d780ca1617ad4be5ac15b501cc2414970ffa2bb2bbc2bd5a68fa5", size = 252355, upload-time = "2025-08-11T12:07:32.205Z" }, + { url = "https://files.pythonhosted.org/packages/d0/a1/dad75d23a90c29c02b5d6f3d7c10ab36c3197613be5d07ec49c7791e186c/multidict-6.6.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:a693fc5ed9bdd1c9e898013e0da4dcc640de7963a371c0bd458e50e046bf6438", size = 250061, upload-time = "2025-08-11T12:07:33.623Z" }, + { url = "https://files.pythonhosted.org/packages/b8/1a/ac2216b61c7f116edab6dc3378cca6c70dc019c9a457ff0d754067c58b20/multidict-6.6.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:190766dac95aab54cae5b152a56520fd99298f32a1266d66d27fdd1b5ac00f4e", size = 249675, upload-time = "2025-08-11T12:07:34.958Z" }, + { url = "https://files.pythonhosted.org/packages/d4/79/1916af833b800d13883e452e8e0977c065c4ee3ab7a26941fbfdebc11895/multidict-6.6.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:34d8f2a5ffdceab9dcd97c7a016deb2308531d5f0fced2bb0c9e1df45b3363d7", size = 261247, upload-time = "2025-08-11T12:07:36.588Z" }, + { url = "https://files.pythonhosted.org/packages/c5/65/d1f84fe08ac44a5fc7391cbc20a7cedc433ea616b266284413fd86062f8c/multidict-6.6.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:59e8d40ab1f5a8597abcef00d04845155a5693b5da00d2c93dbe88f2050f2812", size = 257960, upload-time = "2025-08-11T12:07:39.735Z" }, + { url = "https://files.pythonhosted.org/packages/13/b5/29ec78057d377b195ac2c5248c773703a6b602e132a763e20ec0457e7440/multidict-6.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:467fe64138cfac771f0e949b938c2e1ada2b5af22f39692aa9258715e9ea613a", size = 250078, upload-time = "2025-08-11T12:07:41.525Z" }, + { url = "https://files.pythonhosted.org/packages/c4/0e/7e79d38f70a872cae32e29b0d77024bef7834b0afb406ddae6558d9e2414/multidict-6.6.4-cp313-cp313-win32.whl", hash = "sha256:14616a30fe6d0a48d0a48d1a633ab3b8bec4cf293aac65f32ed116f620adfd69", size = 41708, upload-time = "2025-08-11T12:07:43.405Z" }, + { url = "https://files.pythonhosted.org/packages/9d/34/746696dffff742e97cd6a23da953e55d0ea51fa601fa2ff387b3edcfaa2c/multidict-6.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:40cd05eaeb39e2bc8939451f033e57feaa2ac99e07dbca8afe2be450a4a3b6cf", size = 45912, upload-time = "2025-08-11T12:07:45.082Z" }, + { url = "https://files.pythonhosted.org/packages/c7/87/3bac136181e271e29170d8d71929cdeddeb77f3e8b6a0c08da3a8e9da114/multidict-6.6.4-cp313-cp313-win_arm64.whl", hash = "sha256:f6eb37d511bfae9e13e82cb4d1af36b91150466f24d9b2b8a9785816deb16605", size = 43076, upload-time = "2025-08-11T12:07:46.746Z" }, + { url = "https://files.pythonhosted.org/packages/64/94/0a8e63e36c049b571c9ae41ee301ada29c3fee9643d9c2548d7d558a1d99/multidict-6.6.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:6c84378acd4f37d1b507dfa0d459b449e2321b3ba5f2338f9b085cf7a7ba95eb", size = 82812, upload-time = "2025-08-11T12:07:48.402Z" }, + { url = "https://files.pythonhosted.org/packages/25/1a/be8e369dfcd260d2070a67e65dd3990dd635cbd735b98da31e00ea84cd4e/multidict-6.6.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0e0558693063c75f3d952abf645c78f3c5dfdd825a41d8c4d8156fc0b0da6e7e", size = 48313, upload-time = "2025-08-11T12:07:49.679Z" }, + { url = 
"https://files.pythonhosted.org/packages/26/5a/dd4ade298674b2f9a7b06a32c94ffbc0497354df8285f27317c66433ce3b/multidict-6.6.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3f8e2384cb83ebd23fd07e9eada8ba64afc4c759cd94817433ab8c81ee4b403f", size = 46777, upload-time = "2025-08-11T12:07:51.318Z" }, + { url = "https://files.pythonhosted.org/packages/89/db/98aa28bc7e071bfba611ac2ae803c24e96dd3a452b4118c587d3d872c64c/multidict-6.6.4-cp313-cp313t-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:f996b87b420995a9174b2a7c1a8daf7db4750be6848b03eb5e639674f7963773", size = 229321, upload-time = "2025-08-11T12:07:52.965Z" }, + { url = "https://files.pythonhosted.org/packages/c7/bc/01ddda2a73dd9d167bd85d0e8ef4293836a8f82b786c63fb1a429bc3e678/multidict-6.6.4-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc356250cffd6e78416cf5b40dc6a74f1edf3be8e834cf8862d9ed5265cf9b0e", size = 249954, upload-time = "2025-08-11T12:07:54.423Z" }, + { url = "https://files.pythonhosted.org/packages/06/78/6b7c0f020f9aa0acf66d0ab4eb9f08375bac9a50ff5e3edb1c4ccd59eafc/multidict-6.6.4-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:dadf95aa862714ea468a49ad1e09fe00fcc9ec67d122f6596a8d40caf6cec7d0", size = 228612, upload-time = "2025-08-11T12:07:55.914Z" }, + { url = "https://files.pythonhosted.org/packages/00/44/3faa416f89b2d5d76e9d447296a81521e1c832ad6e40b92f990697b43192/multidict-6.6.4-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7dd57515bebffd8ebd714d101d4c434063322e4fe24042e90ced41f18b6d3395", size = 257528, upload-time = "2025-08-11T12:07:57.371Z" }, + { url = "https://files.pythonhosted.org/packages/05/5f/77c03b89af0fcb16f018f668207768191fb9dcfb5e3361a5e706a11db2c9/multidict-6.6.4-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:967af5f238ebc2eb1da4e77af5492219fbd9b4b812347da39a7b5f5c72c0fa45", size = 256329, upload-time = "2025-08-11T12:07:58.844Z" }, + { url = "https://files.pythonhosted.org/packages/cf/e9/ed750a2a9afb4f8dc6f13dc5b67b514832101b95714f1211cd42e0aafc26/multidict-6.6.4-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a4c6875c37aae9794308ec43e3530e4aa0d36579ce38d89979bbf89582002bb", size = 247928, upload-time = "2025-08-11T12:08:01.037Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b5/e0571bc13cda277db7e6e8a532791d4403dacc9850006cb66d2556e649c0/multidict-6.6.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:7f683a551e92bdb7fac545b9c6f9fa2aebdeefa61d607510b3533286fcab67f5", size = 245228, upload-time = "2025-08-11T12:08:02.96Z" }, + { url = "https://files.pythonhosted.org/packages/f3/a3/69a84b0eccb9824491f06368f5b86e72e4af54c3067c37c39099b6687109/multidict-6.6.4-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:3ba5aaf600edaf2a868a391779f7a85d93bed147854925f34edd24cc70a3e141", size = 235869, upload-time = "2025-08-11T12:08:04.746Z" }, + { url = "https://files.pythonhosted.org/packages/a9/9d/28802e8f9121a6a0804fa009debf4e753d0a59969ea9f70be5f5fdfcb18f/multidict-6.6.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:580b643b7fd2c295d83cad90d78419081f53fd532d1f1eb67ceb7060f61cff0d", size = 243446, upload-time = "2025-08-11T12:08:06.332Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/ea/6c98add069b4878c1d66428a5f5149ddb6d32b1f9836a826ac764b9940be/multidict-6.6.4-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:37b7187197da6af3ee0b044dbc9625afd0c885f2800815b228a0e70f9a7f473d", size = 252299, upload-time = "2025-08-11T12:08:07.931Z" }, + { url = "https://files.pythonhosted.org/packages/3a/09/8fe02d204473e14c0af3affd50af9078839dfca1742f025cca765435d6b4/multidict-6.6.4-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e1b93790ed0bc26feb72e2f08299691ceb6da5e9e14a0d13cc74f1869af327a0", size = 246926, upload-time = "2025-08-11T12:08:09.467Z" }, + { url = "https://files.pythonhosted.org/packages/37/3d/7b1e10d774a6df5175ecd3c92bff069e77bed9ec2a927fdd4ff5fe182f67/multidict-6.6.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a506a77ddee1efcca81ecbeae27ade3e09cdf21a8ae854d766c2bb4f14053f92", size = 243383, upload-time = "2025-08-11T12:08:10.981Z" }, + { url = "https://files.pythonhosted.org/packages/50/b0/a6fae46071b645ae98786ab738447de1ef53742eaad949f27e960864bb49/multidict-6.6.4-cp313-cp313t-win32.whl", hash = "sha256:f93b2b2279883d1d0a9e1bd01f312d6fc315c5e4c1f09e112e4736e2f650bc4e", size = 47775, upload-time = "2025-08-11T12:08:12.439Z" }, + { url = "https://files.pythonhosted.org/packages/b2/0a/2436550b1520091af0600dff547913cb2d66fbac27a8c33bc1b1bccd8d98/multidict-6.6.4-cp313-cp313t-win_amd64.whl", hash = "sha256:6d46a180acdf6e87cc41dc15d8f5c2986e1e8739dc25dbb7dac826731ef381a4", size = 53100, upload-time = "2025-08-11T12:08:13.823Z" }, + { url = "https://files.pythonhosted.org/packages/97/ea/43ac51faff934086db9c072a94d327d71b7d8b40cd5dcb47311330929ef0/multidict-6.6.4-cp313-cp313t-win_arm64.whl", hash = "sha256:756989334015e3335d087a27331659820d53ba432befdef6a718398b0a8493ad", size = 45501, upload-time = "2025-08-11T12:08:15.173Z" }, + { url = "https://files.pythonhosted.org/packages/fd/69/b547032297c7e63ba2af494edba695d781af8a0c6e89e4d06cf848b21d80/multidict-6.6.4-py3-none-any.whl", hash = "sha256:27d8f8e125c07cb954e54d75d04905a9bba8a439c1d84aca94949d4d03d8601c", size = 12313, upload-time = "2025-08-11T12:08:46.891Z" }, +] + +[[package]] +name = "multiprocess" +version = "0.70.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/76/6e712a2623d146d314f17598df5de7224c85c0060ef63fd95cc15a25b3fa/multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee", size = 134980, upload-time = "2024-01-28T18:52:15.731Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ab/1e6e8009e380e22254ff539ebe117861e5bdb3bff1fc977920972237c6c7/multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec", size = 134982, upload-time = "2024-01-28T18:52:17.783Z" }, + { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, + { 
url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, + { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, +] + +[[package]] +name = "nanochat" +version = "0.1.0" +source = { editable = "." } +dependencies = [ + { name = "datasets" }, + { name = "fastapi" }, + { name = "files-to-prompt" }, + { name = "numpy" }, + { name = "psutil" }, + { name = "regex" }, + { name = "tiktoken" }, + { name = "tokenizers" }, + { name = "torch" }, + { name = "uvicorn" }, + { name = "wandb" }, +] + +[package.dev-dependencies] +dev = [ + { name = "maturin" }, + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "datasets", specifier = ">=4.0.0" }, + { name = "fastapi", specifier = ">=0.117.1" }, + { name = "files-to-prompt", specifier = ">=0.6" }, + { name = "numpy", specifier = "==1.26.4" }, + { name = "psutil", specifier = ">=7.1.0" }, + { name = "regex", specifier = ">=2025.9.1" }, + { name = "tiktoken", specifier = ">=0.11.0" }, + { name = "tokenizers", specifier = ">=0.22.0" }, + { name = "torch", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "uvicorn", specifier = ">=0.36.0" }, + { name = "wandb", specifier = ">=0.21.3" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "maturin", specifier = ">=1.9.4" }, + { name = "pytest", specifier = ">=8.0.0" }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, +] + +[[package]] +name = "networkx" +version = "3.5" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version 
== '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468, upload-time = "2024-02-05T23:48:01.194Z" }, + { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411, upload-time = "2024-02-05T23:48:29.038Z" }, + { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016, upload-time = "2024-02-05T23:48:54.098Z" }, + { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889, upload-time = "2024-02-05T23:49:25.361Z" }, + { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746, upload-time = "2024-02-05T23:49:51.983Z" }, + { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620, upload-time = "2024-02-05T23:50:22.515Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659, upload-time = "2024-02-05T23:50:35.834Z" }, + { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = 
"sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905, upload-time = "2024-02-05T23:51:03.701Z" }, + { url = "https://files.pythonhosted.org/packages/11/57/baae43d14fe163fa0e4c47f307b6b2511ab8d7d30177c491960504252053/numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71", size = 20630554, upload-time = "2024-02-05T23:51:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/1a/2e/151484f49fd03944c4a3ad9c418ed193cfd02724e138ac8a9505d056c582/numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef", size = 13997127, upload-time = "2024-02-05T23:52:15.314Z" }, + { url = "https://files.pythonhosted.org/packages/79/ae/7e5b85136806f9dadf4878bf73cf223fe5c2636818ba3ab1c585d0403164/numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e", size = 14222994, upload-time = "2024-02-05T23:52:47.569Z" }, + { url = "https://files.pythonhosted.org/packages/3a/d0/edc009c27b406c4f9cbc79274d6e46d634d139075492ad055e3d68445925/numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5", size = 18252005, upload-time = "2024-02-05T23:53:15.637Z" }, + { url = "https://files.pythonhosted.org/packages/09/bf/2b1aaf8f525f2923ff6cfcf134ae5e750e279ac65ebf386c75a0cf6da06a/numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a", size = 13885297, upload-time = "2024-02-05T23:53:42.16Z" }, + { url = "https://files.pythonhosted.org/packages/df/a0/4e0f14d847cfc2a633a1c8621d00724f3206cfeddeb66d35698c4e2cf3d2/numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a", size = 18093567, upload-time = "2024-02-05T23:54:11.696Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b7/a734c733286e10a7f1a8ad1ae8c90f2d33bf604a96548e0a4a3a6739b468/numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20", size = 5968812, upload-time = "2024-02-05T23:54:26.453Z" }, + { url = "https://files.pythonhosted.org/packages/3f/6b/5610004206cf7f8e7ad91c5a85a8c71b2f2f8051a0c0c4d5916b76d6cbb2/numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", size = 15811913, upload-time = "2024-02-05T23:54:53.933Z" }, + { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901, upload-time = "2024-02-05T23:55:32.801Z" }, + { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868, upload-time = "2024-02-05T23:55:56.28Z" }, + { url = "https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size 
= 13925109, upload-time = "2024-02-05T23:56:20.368Z" }, + { url = "https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613, upload-time = "2024-02-05T23:56:56.054Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172, upload-time = "2024-02-05T23:57:21.56Z" }, + { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643, upload-time = "2024-02-05T23:57:56.585Z" }, + { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803, upload-time = "2024-02-05T23:58:08.963Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754, upload-time = "2024-02-05T23:58:36.364Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] 
+name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pandas" +version = "2.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/8e/0e90233ac205ad182bd6b422532695d2b9414944a280488105d598c70023/pandas-2.3.2.tar.gz", hash = "sha256:ab7b58f8f82706890924ccdfb5f48002b83d2b5a3845976a9fb705d36c34dcdb", size = 4488684, upload-time = "2025-08-21T10:28:29.257Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/16/a8eeb70aad84ccbf14076793f90e0031eded63c1899aeae9fdfbf37881f4/pandas-2.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52bc29a946304c360561974c6542d1dd628ddafa69134a7131fdfd6a5d7a1a35", size = 11539648, upload-time = "2025-08-21T10:26:36.236Z" }, + { url = "https://files.pythonhosted.org/packages/47/f1/c5bdaea13bf3708554d93e948b7ea74121ce6e0d59537ca4c4f77731072b/pandas-2.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:220cc5c35ffaa764dd5bb17cf42df283b5cb7fdf49e10a7b053a06c9cb48ee2b", size = 10786923, upload-time 
= "2025-08-21T10:26:40.518Z" }, + { url = "https://files.pythonhosted.org/packages/bb/10/811fa01476d29ffed692e735825516ad0e56d925961819e6126b4ba32147/pandas-2.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42c05e15111221384019897df20c6fe893b2f697d03c811ee67ec9e0bb5a3424", size = 11726241, upload-time = "2025-08-21T10:26:43.175Z" }, + { url = "https://files.pythonhosted.org/packages/c4/6a/40b043b06e08df1ea1b6d20f0e0c2f2c4ec8c4f07d1c92948273d943a50b/pandas-2.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc03acc273c5515ab69f898df99d9d4f12c4d70dbfc24c3acc6203751d0804cf", size = 12349533, upload-time = "2025-08-21T10:26:46.611Z" }, + { url = "https://files.pythonhosted.org/packages/e2/ea/2e081a2302e41a9bca7056659fdd2b85ef94923723e41665b42d65afd347/pandas-2.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d25c20a03e8870f6339bcf67281b946bd20b86f1a544ebbebb87e66a8d642cba", size = 13202407, upload-time = "2025-08-21T10:26:49.068Z" }, + { url = "https://files.pythonhosted.org/packages/f4/12/7ff9f6a79e2ee8869dcf70741ef998b97ea20050fe25f83dc759764c1e32/pandas-2.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:21bb612d148bb5860b7eb2c10faacf1a810799245afd342cf297d7551513fbb6", size = 13837212, upload-time = "2025-08-21T10:26:51.832Z" }, + { url = "https://files.pythonhosted.org/packages/d8/df/5ab92fcd76455a632b3db34a746e1074d432c0cdbbd28d7cd1daba46a75d/pandas-2.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:b62d586eb25cb8cb70a5746a378fc3194cb7f11ea77170d59f889f5dfe3cec7a", size = 11338099, upload-time = "2025-08-21T10:26:54.382Z" }, + { url = "https://files.pythonhosted.org/packages/7a/59/f3e010879f118c2d400902d2d871c2226cef29b08c09fb8dc41111730400/pandas-2.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1333e9c299adcbb68ee89a9bb568fc3f20f9cbb419f1dd5225071e6cddb2a743", size = 11563308, upload-time = "2025-08-21T10:26:56.656Z" }, + { url = "https://files.pythonhosted.org/packages/38/18/48f10f1cc5c397af59571d638d211f494dba481f449c19adbd282aa8f4ca/pandas-2.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:76972bcbd7de8e91ad5f0ca884a9f2c477a2125354af624e022c49e5bd0dfff4", size = 10820319, upload-time = "2025-08-21T10:26:59.162Z" }, + { url = "https://files.pythonhosted.org/packages/95/3b/1e9b69632898b048e223834cd9702052bcf06b15e1ae716eda3196fb972e/pandas-2.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b98bdd7c456a05eef7cd21fd6b29e3ca243591fe531c62be94a2cc987efb5ac2", size = 11790097, upload-time = "2025-08-21T10:27:02.204Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ef/0e2ffb30b1f7fbc9a588bd01e3c14a0d96854d09a887e15e30cc19961227/pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d81573b3f7db40d020983f78721e9bfc425f411e616ef019a10ebf597aedb2e", size = 12397958, upload-time = "2025-08-21T10:27:05.409Z" }, + { url = "https://files.pythonhosted.org/packages/23/82/e6b85f0d92e9afb0e7f705a51d1399b79c7380c19687bfbf3d2837743249/pandas-2.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e190b738675a73b581736cc8ec71ae113d6c3768d0bd18bffa5b9a0927b0b6ea", size = 13225600, upload-time = "2025-08-21T10:27:07.791Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f1/f682015893d9ed51611948bd83683670842286a8edd4f68c2c1c3b231eef/pandas-2.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c253828cb08f47488d60f43c5fc95114c771bbfff085da54bfc79cb4f9e3a372", size = 13879433, upload-time = "2025-08-21T10:27:10.347Z" }, + 
{ url = "https://files.pythonhosted.org/packages/a7/e7/ae86261695b6c8a36d6a4c8d5f9b9ede8248510d689a2f379a18354b37d7/pandas-2.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:9467697b8083f9667b212633ad6aa4ab32436dcbaf4cd57325debb0ddef2012f", size = 11336557, upload-time = "2025-08-21T10:27:12.983Z" }, + { url = "https://files.pythonhosted.org/packages/ec/db/614c20fb7a85a14828edd23f1c02db58a30abf3ce76f38806155d160313c/pandas-2.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fbb977f802156e7a3f829e9d1d5398f6192375a3e2d1a9ee0803e35fe70a2b9", size = 11587652, upload-time = "2025-08-21T10:27:15.888Z" }, + { url = "https://files.pythonhosted.org/packages/99/b0/756e52f6582cade5e746f19bad0517ff27ba9c73404607c0306585c201b3/pandas-2.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b9b52693123dd234b7c985c68b709b0b009f4521000d0525f2b95c22f15944b", size = 10717686, upload-time = "2025-08-21T10:27:18.486Z" }, + { url = "https://files.pythonhosted.org/packages/37/4c/dd5ccc1e357abfeee8353123282de17997f90ff67855f86154e5a13b81e5/pandas-2.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd281310d4f412733f319a5bc552f86d62cddc5f51d2e392c8787335c994175", size = 11278722, upload-time = "2025-08-21T10:27:21.149Z" }, + { url = "https://files.pythonhosted.org/packages/d3/a4/f7edcfa47e0a88cda0be8b068a5bae710bf264f867edfdf7b71584ace362/pandas-2.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96d31a6b4354e3b9b8a2c848af75d31da390657e3ac6f30c05c82068b9ed79b9", size = 11987803, upload-time = "2025-08-21T10:27:23.767Z" }, + { url = "https://files.pythonhosted.org/packages/f6/61/1bce4129f93ab66f1c68b7ed1c12bac6a70b1b56c5dab359c6bbcd480b52/pandas-2.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:df4df0b9d02bb873a106971bb85d448378ef14b86ba96f035f50bbd3688456b4", size = 12766345, upload-time = "2025-08-21T10:27:26.6Z" }, + { url = "https://files.pythonhosted.org/packages/8e/46/80d53de70fee835531da3a1dae827a1e76e77a43ad22a8cd0f8142b61587/pandas-2.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:213a5adf93d020b74327cb2c1b842884dbdd37f895f42dcc2f09d451d949f811", size = 13439314, upload-time = "2025-08-21T10:27:29.213Z" }, + { url = "https://files.pythonhosted.org/packages/28/30/8114832daff7489f179971dbc1d854109b7f4365a546e3ea75b6516cea95/pandas-2.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:8c13b81a9347eb8c7548f53fd9a4f08d4dfe996836543f805c987bafa03317ae", size = 10983326, upload-time = "2025-08-21T10:27:31.901Z" }, + { url = "https://files.pythonhosted.org/packages/27/64/a2f7bf678af502e16b472527735d168b22b7824e45a4d7e96a4fbb634b59/pandas-2.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0c6ecbac99a354a051ef21c5307601093cb9e0f4b1855984a084bfec9302699e", size = 11531061, upload-time = "2025-08-21T10:27:34.647Z" }, + { url = "https://files.pythonhosted.org/packages/54/4c/c3d21b2b7769ef2f4c2b9299fcadd601efa6729f1357a8dbce8dd949ed70/pandas-2.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c6f048aa0fd080d6a06cc7e7537c09b53be6642d330ac6f54a600c3ace857ee9", size = 10668666, upload-time = "2025-08-21T10:27:37.203Z" }, + { url = "https://files.pythonhosted.org/packages/50/e2/f775ba76ecfb3424d7f5862620841cf0edb592e9abd2d2a5387d305fe7a8/pandas-2.3.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0064187b80a5be6f2f9c9d6bdde29372468751dfa89f4211a3c5871854cfbf7a", size = 11332835, upload-time = "2025-08-21T10:27:40.188Z" }, + { url = 
"https://files.pythonhosted.org/packages/8f/52/0634adaace9be2d8cac9ef78f05c47f3a675882e068438b9d7ec7ef0c13f/pandas-2.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ac8c320bded4718b298281339c1a50fb00a6ba78cb2a63521c39bec95b0209b", size = 12057211, upload-time = "2025-08-21T10:27:43.117Z" }, + { url = "https://files.pythonhosted.org/packages/0b/9d/2df913f14b2deb9c748975fdb2491da1a78773debb25abbc7cbc67c6b549/pandas-2.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:114c2fe4f4328cf98ce5716d1532f3ab79c5919f95a9cfee81d9140064a2e4d6", size = 12749277, upload-time = "2025-08-21T10:27:45.474Z" }, + { url = "https://files.pythonhosted.org/packages/87/af/da1a2417026bd14d98c236dba88e39837182459d29dcfcea510b2ac9e8a1/pandas-2.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:48fa91c4dfb3b2b9bfdb5c24cd3567575f4e13f9636810462ffed8925352be5a", size = 13415256, upload-time = "2025-08-21T10:27:49.885Z" }, + { url = "https://files.pythonhosted.org/packages/22/3c/f2af1ce8840ef648584a6156489636b5692c162771918aa95707c165ad2b/pandas-2.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:12d039facec710f7ba305786837d0225a3444af7bbd9c15c32ca2d40d157ed8b", size = 10982579, upload-time = "2025-08-21T10:28:08.435Z" }, + { url = "https://files.pythonhosted.org/packages/f3/98/8df69c4097a6719e357dc249bf437b8efbde808038268e584421696cbddf/pandas-2.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c624b615ce97864eb588779ed4046186f967374185c047070545253a52ab2d57", size = 12028163, upload-time = "2025-08-21T10:27:52.232Z" }, + { url = "https://files.pythonhosted.org/packages/0e/23/f95cbcbea319f349e10ff90db488b905c6883f03cbabd34f6b03cbc3c044/pandas-2.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0cee69d583b9b128823d9514171cabb6861e09409af805b54459bd0c821a35c2", size = 11391860, upload-time = "2025-08-21T10:27:54.673Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1b/6a984e98c4abee22058aa75bfb8eb90dce58cf8d7296f8bc56c14bc330b0/pandas-2.3.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2319656ed81124982900b4c37f0e0c58c015af9a7bbc62342ba5ad07ace82ba9", size = 11309830, upload-time = "2025-08-21T10:27:56.957Z" }, + { url = "https://files.pythonhosted.org/packages/15/d5/f0486090eb18dd8710bf60afeaf638ba6817047c0c8ae5c6a25598665609/pandas-2.3.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b37205ad6f00d52f16b6d09f406434ba928c1a1966e2771006a9033c736d30d2", size = 11883216, upload-time = "2025-08-21T10:27:59.302Z" }, + { url = "https://files.pythonhosted.org/packages/10/86/692050c119696da19e20245bbd650d8dfca6ceb577da027c3a73c62a047e/pandas-2.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:837248b4fc3a9b83b9c6214699a13f069dc13510a6a6d7f9ba33145d2841a012", size = 12699743, upload-time = "2025-08-21T10:28:02.447Z" }, + { url = "https://files.pythonhosted.org/packages/cd/d7/612123674d7b17cf345aad0a10289b2a384bff404e0463a83c4a3a59d205/pandas-2.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d2c3554bd31b731cd6490d94a28f3abb8dd770634a9e06eb6d2911b9827db370", size = 13186141, upload-time = "2025-08-21T10:28:05.377Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/23/e8/21db9c9987b0e728855bd57bff6984f67952bea55d6f75e055c46b5383e8/platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf", size = 21634, upload-time = 
"2025-08-26T14:32:04.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/4b/2028861e724d3bd36227adfa20d3fd24c3fc6d52032f4a93c133be5d17ce/platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85", size = 18654, upload-time = "2025-08-26T14:32:02.735Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "propcache" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/16/43264e4a779dd8588c21a70f0709665ee8f611211bdd2c87d952cfa7c776/propcache-0.3.2.tar.gz", hash = "sha256:20d7d62e4e7ef05f221e0db2856b979540686342e7dd9973b815599c7057e168", size = 44139, upload-time = "2025-06-09T22:56:06.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/14/510deed325e262afeb8b360043c5d7c960da7d3ecd6d6f9496c9c56dc7f4/propcache-0.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:22d9962a358aedbb7a2e36187ff273adeaab9743373a272976d2e348d08c7770", size = 73178, upload-time = "2025-06-09T22:53:40.126Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4e/ad52a7925ff01c1325653a730c7ec3175a23f948f08626a534133427dcff/propcache-0.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d0fda578d1dc3f77b6b5a5dce3b9ad69a8250a891760a548df850a5e8da87f3", size = 43133, upload-time = "2025-06-09T22:53:41.965Z" }, + { url = "https://files.pythonhosted.org/packages/63/7c/e9399ba5da7780871db4eac178e9c2e204c23dd3e7d32df202092a1ed400/propcache-0.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3def3da3ac3ce41562d85db655d18ebac740cb3fa4367f11a52b3da9d03a5cc3", size = 43039, upload-time = "2025-06-09T22:53:43.268Z" }, + { url = "https://files.pythonhosted.org/packages/22/e1/58da211eb8fdc6fc854002387d38f415a6ca5f5c67c1315b204a5d3e9d7a/propcache-0.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bec58347a5a6cebf239daba9bda37dffec5b8d2ce004d9fe4edef3d2815137e", size = 201903, upload-time = "2025-06-09T22:53:44.872Z" }, + { url = "https://files.pythonhosted.org/packages/c4/0a/550ea0f52aac455cb90111c8bab995208443e46d925e51e2f6ebdf869525/propcache-0.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55ffda449a507e9fbd4aca1a7d9aa6753b07d6166140e5a18d2ac9bc49eac220", size = 213362, upload-time = "2025-06-09T22:53:46.707Z" }, + { url = "https://files.pythonhosted.org/packages/5a/af/9893b7d878deda9bb69fcf54600b247fba7317761b7db11fede6e0f28bd0/propcache-0.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64a67fb39229a8a8491dd42f864e5e263155e729c2e7ff723d6e25f596b1e8cb", size = 210525, upload-time = "2025-06-09T22:53:48.547Z" }, + { url = 
"https://files.pythonhosted.org/packages/7c/bb/38fd08b278ca85cde36d848091ad2b45954bc5f15cce494bb300b9285831/propcache-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9da1cf97b92b51253d5b68cf5a2b9e0dafca095e36b7f2da335e27dc6172a614", size = 198283, upload-time = "2025-06-09T22:53:50.067Z" }, + { url = "https://files.pythonhosted.org/packages/78/8c/9fe55bd01d362bafb413dfe508c48753111a1e269737fa143ba85693592c/propcache-0.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f559e127134b07425134b4065be45b166183fdcb433cb6c24c8e4149056ad50", size = 191872, upload-time = "2025-06-09T22:53:51.438Z" }, + { url = "https://files.pythonhosted.org/packages/54/14/4701c33852937a22584e08abb531d654c8bcf7948a8f87ad0a4822394147/propcache-0.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aff2e4e06435d61f11a428360a932138d0ec288b0a31dd9bd78d200bd4a2b339", size = 199452, upload-time = "2025-06-09T22:53:53.229Z" }, + { url = "https://files.pythonhosted.org/packages/16/44/447f2253d859602095356007657ee535e0093215ea0b3d1d6a41d16e5201/propcache-0.3.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:4927842833830942a5d0a56e6f4839bc484785b8e1ce8d287359794818633ba0", size = 191567, upload-time = "2025-06-09T22:53:54.541Z" }, + { url = "https://files.pythonhosted.org/packages/f2/b3/e4756258749bb2d3b46defcff606a2f47410bab82be5824a67e84015b267/propcache-0.3.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6107ddd08b02654a30fb8ad7a132021759d750a82578b94cd55ee2772b6ebea2", size = 193015, upload-time = "2025-06-09T22:53:56.44Z" }, + { url = "https://files.pythonhosted.org/packages/1e/df/e6d3c7574233164b6330b9fd697beeac402afd367280e6dc377bb99b43d9/propcache-0.3.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:70bd8b9cd6b519e12859c99f3fc9a93f375ebd22a50296c3a295028bea73b9e7", size = 204660, upload-time = "2025-06-09T22:53:57.839Z" }, + { url = "https://files.pythonhosted.org/packages/b2/53/e4d31dd5170b4a0e2e6b730f2385a96410633b4833dc25fe5dffd1f73294/propcache-0.3.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2183111651d710d3097338dd1893fcf09c9f54e27ff1a8795495a16a469cc90b", size = 206105, upload-time = "2025-06-09T22:53:59.638Z" }, + { url = "https://files.pythonhosted.org/packages/7f/fe/74d54cf9fbe2a20ff786e5f7afcfde446588f0cf15fb2daacfbc267b866c/propcache-0.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:fb075ad271405dcad8e2a7ffc9a750a3bf70e533bd86e89f0603e607b93aa64c", size = 196980, upload-time = "2025-06-09T22:54:01.071Z" }, + { url = "https://files.pythonhosted.org/packages/22/ec/c469c9d59dada8a7679625e0440b544fe72e99311a4679c279562051f6fc/propcache-0.3.2-cp310-cp310-win32.whl", hash = "sha256:404d70768080d3d3bdb41d0771037da19d8340d50b08e104ca0e7f9ce55fce70", size = 37679, upload-time = "2025-06-09T22:54:03.003Z" }, + { url = "https://files.pythonhosted.org/packages/38/35/07a471371ac89d418f8d0b699c75ea6dca2041fbda360823de21f6a9ce0a/propcache-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:7435d766f978b4ede777002e6b3b6641dd229cd1da8d3d3106a45770365f9ad9", size = 41459, upload-time = "2025-06-09T22:54:04.134Z" }, + { url = "https://files.pythonhosted.org/packages/80/8d/e8b436717ab9c2cfc23b116d2c297305aa4cd8339172a456d61ebf5669b8/propcache-0.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0b8d2f607bd8f80ddc04088bc2a037fdd17884a6fcadc47a96e334d72f3717be", size = 74207, upload-time = "2025-06-09T22:54:05.399Z" }, + { url = 
"https://files.pythonhosted.org/packages/d6/29/1e34000e9766d112171764b9fa3226fa0153ab565d0c242c70e9945318a7/propcache-0.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06766d8f34733416e2e34f46fea488ad5d60726bb9481d3cddf89a6fa2d9603f", size = 43648, upload-time = "2025-06-09T22:54:08.023Z" }, + { url = "https://files.pythonhosted.org/packages/46/92/1ad5af0df781e76988897da39b5f086c2bf0f028b7f9bd1f409bb05b6874/propcache-0.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2dc1f4a1df4fecf4e6f68013575ff4af84ef6f478fe5344317a65d38a8e6dc9", size = 43496, upload-time = "2025-06-09T22:54:09.228Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ce/e96392460f9fb68461fabab3e095cb00c8ddf901205be4eae5ce246e5b7e/propcache-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be29c4f4810c5789cf10ddf6af80b041c724e629fa51e308a7a0fb19ed1ef7bf", size = 217288, upload-time = "2025-06-09T22:54:10.466Z" }, + { url = "https://files.pythonhosted.org/packages/c5/2a/866726ea345299f7ceefc861a5e782b045545ae6940851930a6adaf1fca6/propcache-0.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59d61f6970ecbd8ff2e9360304d5c8876a6abd4530cb752c06586849ac8a9dc9", size = 227456, upload-time = "2025-06-09T22:54:11.828Z" }, + { url = "https://files.pythonhosted.org/packages/de/03/07d992ccb6d930398689187e1b3c718339a1c06b8b145a8d9650e4726166/propcache-0.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:62180e0b8dbb6b004baec00a7983e4cc52f5ada9cd11f48c3528d8cfa7b96a66", size = 225429, upload-time = "2025-06-09T22:54:13.823Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e6/116ba39448753b1330f48ab8ba927dcd6cf0baea8a0ccbc512dfb49ba670/propcache-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c144ca294a204c470f18cf4c9d78887810d04a3e2fbb30eea903575a779159df", size = 213472, upload-time = "2025-06-09T22:54:15.232Z" }, + { url = "https://files.pythonhosted.org/packages/a6/85/f01f5d97e54e428885a5497ccf7f54404cbb4f906688a1690cd51bf597dc/propcache-0.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5c2a784234c28854878d68978265617aa6dc0780e53d44b4d67f3651a17a9a2", size = 204480, upload-time = "2025-06-09T22:54:17.104Z" }, + { url = "https://files.pythonhosted.org/packages/e3/79/7bf5ab9033b8b8194cc3f7cf1aaa0e9c3256320726f64a3e1f113a812dce/propcache-0.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5745bc7acdafa978ca1642891b82c19238eadc78ba2aaa293c6863b304e552d7", size = 214530, upload-time = "2025-06-09T22:54:18.512Z" }, + { url = "https://files.pythonhosted.org/packages/31/0b/bd3e0c00509b609317df4a18e6b05a450ef2d9a963e1d8bc9c9415d86f30/propcache-0.3.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:c0075bf773d66fa8c9d41f66cc132ecc75e5bb9dd7cce3cfd14adc5ca184cb95", size = 205230, upload-time = "2025-06-09T22:54:19.947Z" }, + { url = "https://files.pythonhosted.org/packages/7a/23/fae0ff9b54b0de4e819bbe559508da132d5683c32d84d0dc2ccce3563ed4/propcache-0.3.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5f57aa0847730daceff0497f417c9de353c575d8da3579162cc74ac294c5369e", size = 206754, upload-time = "2025-06-09T22:54:21.716Z" }, + { url = "https://files.pythonhosted.org/packages/b7/7f/ad6a3c22630aaa5f618b4dc3c3598974a72abb4c18e45a50b3cdd091eb2f/propcache-0.3.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:eef914c014bf72d18efb55619447e0aecd5fb7c2e3fa7441e2e5d6099bddff7e", size = 218430, upload-time = 
"2025-06-09T22:54:23.17Z" }, + { url = "https://files.pythonhosted.org/packages/5b/2c/ba4f1c0e8a4b4c75910742f0d333759d441f65a1c7f34683b4a74c0ee015/propcache-0.3.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:2a4092e8549031e82facf3decdbc0883755d5bbcc62d3aea9d9e185549936dcf", size = 223884, upload-time = "2025-06-09T22:54:25.539Z" }, + { url = "https://files.pythonhosted.org/packages/88/e4/ebe30fc399e98572019eee82ad0caf512401661985cbd3da5e3140ffa1b0/propcache-0.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:85871b050f174bc0bfb437efbdb68aaf860611953ed12418e4361bc9c392749e", size = 211480, upload-time = "2025-06-09T22:54:26.892Z" }, + { url = "https://files.pythonhosted.org/packages/96/0a/7d5260b914e01d1d0906f7f38af101f8d8ed0dc47426219eeaf05e8ea7c2/propcache-0.3.2-cp311-cp311-win32.whl", hash = "sha256:36c8d9b673ec57900c3554264e630d45980fd302458e4ac801802a7fd2ef7897", size = 37757, upload-time = "2025-06-09T22:54:28.241Z" }, + { url = "https://files.pythonhosted.org/packages/e1/2d/89fe4489a884bc0da0c3278c552bd4ffe06a1ace559db5ef02ef24ab446b/propcache-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e53af8cb6a781b02d2ea079b5b853ba9430fcbe18a8e3ce647d5982a3ff69f39", size = 41500, upload-time = "2025-06-09T22:54:29.4Z" }, + { url = "https://files.pythonhosted.org/packages/a8/42/9ca01b0a6f48e81615dca4765a8f1dd2c057e0540f6116a27dc5ee01dfb6/propcache-0.3.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:8de106b6c84506b31c27168582cd3cb3000a6412c16df14a8628e5871ff83c10", size = 73674, upload-time = "2025-06-09T22:54:30.551Z" }, + { url = "https://files.pythonhosted.org/packages/af/6e/21293133beb550f9c901bbece755d582bfaf2176bee4774000bd4dd41884/propcache-0.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:28710b0d3975117239c76600ea351934ac7b5ff56e60953474342608dbbb6154", size = 43570, upload-time = "2025-06-09T22:54:32.296Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c8/0393a0a3a2b8760eb3bde3c147f62b20044f0ddac81e9d6ed7318ec0d852/propcache-0.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce26862344bdf836650ed2487c3d724b00fbfec4233a1013f597b78c1cb73615", size = 43094, upload-time = "2025-06-09T22:54:33.929Z" }, + { url = "https://files.pythonhosted.org/packages/37/2c/489afe311a690399d04a3e03b069225670c1d489eb7b044a566511c1c498/propcache-0.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bca54bd347a253af2cf4544bbec232ab982f4868de0dd684246b67a51bc6b1db", size = 226958, upload-time = "2025-06-09T22:54:35.186Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ca/63b520d2f3d418c968bf596839ae26cf7f87bead026b6192d4da6a08c467/propcache-0.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55780d5e9a2ddc59711d727226bb1ba83a22dd32f64ee15594b9392b1f544eb1", size = 234894, upload-time = "2025-06-09T22:54:36.708Z" }, + { url = "https://files.pythonhosted.org/packages/11/60/1d0ed6fff455a028d678df30cc28dcee7af77fa2b0e6962ce1df95c9a2a9/propcache-0.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:035e631be25d6975ed87ab23153db6a73426a48db688070d925aa27e996fe93c", size = 233672, upload-time = "2025-06-09T22:54:38.062Z" }, + { url = "https://files.pythonhosted.org/packages/37/7c/54fd5301ef38505ab235d98827207176a5c9b2aa61939b10a460ca53e123/propcache-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee6f22b6eaa39297c751d0e80c0d3a454f112f5c6481214fcf4c092074cecd67", size = 224395, upload-time = "2025-06-09T22:54:39.634Z" }, + { url = 
"https://files.pythonhosted.org/packages/ee/1a/89a40e0846f5de05fdc6779883bf46ba980e6df4d2ff8fb02643de126592/propcache-0.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ca3aee1aa955438c4dba34fc20a9f390e4c79967257d830f137bd5a8a32ed3b", size = 212510, upload-time = "2025-06-09T22:54:41.565Z" }, + { url = "https://files.pythonhosted.org/packages/5e/33/ca98368586c9566a6b8d5ef66e30484f8da84c0aac3f2d9aec6d31a11bd5/propcache-0.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7a4f30862869fa2b68380d677cc1c5fcf1e0f2b9ea0cf665812895c75d0ca3b8", size = 222949, upload-time = "2025-06-09T22:54:43.038Z" }, + { url = "https://files.pythonhosted.org/packages/ba/11/ace870d0aafe443b33b2f0b7efdb872b7c3abd505bfb4890716ad7865e9d/propcache-0.3.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b77ec3c257d7816d9f3700013639db7491a434644c906a2578a11daf13176251", size = 217258, upload-time = "2025-06-09T22:54:44.376Z" }, + { url = "https://files.pythonhosted.org/packages/5b/d2/86fd6f7adffcfc74b42c10a6b7db721d1d9ca1055c45d39a1a8f2a740a21/propcache-0.3.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:cab90ac9d3f14b2d5050928483d3d3b8fb6b4018893fc75710e6aa361ecb2474", size = 213036, upload-time = "2025-06-09T22:54:46.243Z" }, + { url = "https://files.pythonhosted.org/packages/07/94/2d7d1e328f45ff34a0a284cf5a2847013701e24c2a53117e7c280a4316b3/propcache-0.3.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0b504d29f3c47cf6b9e936c1852246c83d450e8e063d50562115a6be6d3a2535", size = 227684, upload-time = "2025-06-09T22:54:47.63Z" }, + { url = "https://files.pythonhosted.org/packages/b7/05/37ae63a0087677e90b1d14710e532ff104d44bc1efa3b3970fff99b891dc/propcache-0.3.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:ce2ac2675a6aa41ddb2a0c9cbff53780a617ac3d43e620f8fd77ba1c84dcfc06", size = 234562, upload-time = "2025-06-09T22:54:48.982Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7c/3f539fcae630408d0bd8bf3208b9a647ccad10976eda62402a80adf8fc34/propcache-0.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b4239611205294cc433845b914131b2a1f03500ff3c1ed093ed216b82621e1", size = 222142, upload-time = "2025-06-09T22:54:50.424Z" }, + { url = "https://files.pythonhosted.org/packages/7c/d2/34b9eac8c35f79f8a962546b3e97e9d4b990c420ee66ac8255d5d9611648/propcache-0.3.2-cp312-cp312-win32.whl", hash = "sha256:df4a81b9b53449ebc90cc4deefb052c1dd934ba85012aa912c7ea7b7e38b60c1", size = 37711, upload-time = "2025-06-09T22:54:52.072Z" }, + { url = "https://files.pythonhosted.org/packages/19/61/d582be5d226cf79071681d1b46b848d6cb03d7b70af7063e33a2787eaa03/propcache-0.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7046e79b989d7fe457bb755844019e10f693752d169076138abf17f31380800c", size = 41479, upload-time = "2025-06-09T22:54:53.234Z" }, + { url = "https://files.pythonhosted.org/packages/dc/d1/8c747fafa558c603c4ca19d8e20b288aa0c7cda74e9402f50f31eb65267e/propcache-0.3.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ca592ed634a73ca002967458187109265e980422116c0a107cf93d81f95af945", size = 71286, upload-time = "2025-06-09T22:54:54.369Z" }, + { url = "https://files.pythonhosted.org/packages/61/99/d606cb7986b60d89c36de8a85d58764323b3a5ff07770a99d8e993b3fa73/propcache-0.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9ecb0aad4020e275652ba3975740f241bd12a61f1a784df044cf7477a02bc252", size = 42425, upload-time = "2025-06-09T22:54:55.642Z" }, + { url = 
"https://files.pythonhosted.org/packages/8c/96/ef98f91bbb42b79e9bb82bdd348b255eb9d65f14dbbe3b1594644c4073f7/propcache-0.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7f08f1cc28bd2eade7a8a3d2954ccc673bb02062e3e7da09bc75d843386b342f", size = 41846, upload-time = "2025-06-09T22:54:57.246Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ad/3f0f9a705fb630d175146cd7b1d2bf5555c9beaed54e94132b21aac098a6/propcache-0.3.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1a342c834734edb4be5ecb1e9fb48cb64b1e2320fccbd8c54bf8da8f2a84c33", size = 208871, upload-time = "2025-06-09T22:54:58.975Z" }, + { url = "https://files.pythonhosted.org/packages/3a/38/2085cda93d2c8b6ec3e92af2c89489a36a5886b712a34ab25de9fbca7992/propcache-0.3.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a544caaae1ac73f1fecfae70ded3e93728831affebd017d53449e3ac052ac1e", size = 215720, upload-time = "2025-06-09T22:55:00.471Z" }, + { url = "https://files.pythonhosted.org/packages/61/c1/d72ea2dc83ac7f2c8e182786ab0fc2c7bd123a1ff9b7975bee671866fe5f/propcache-0.3.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310d11aa44635298397db47a3ebce7db99a4cc4b9bbdfcf6c98a60c8d5261cf1", size = 215203, upload-time = "2025-06-09T22:55:01.834Z" }, + { url = "https://files.pythonhosted.org/packages/af/81/b324c44ae60c56ef12007105f1460d5c304b0626ab0cc6b07c8f2a9aa0b8/propcache-0.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c1396592321ac83157ac03a2023aa6cc4a3cc3cfdecb71090054c09e5a7cce3", size = 206365, upload-time = "2025-06-09T22:55:03.199Z" }, + { url = "https://files.pythonhosted.org/packages/09/73/88549128bb89e66d2aff242488f62869014ae092db63ccea53c1cc75a81d/propcache-0.3.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cabf5b5902272565e78197edb682017d21cf3b550ba0460ee473753f28d23c1", size = 196016, upload-time = "2025-06-09T22:55:04.518Z" }, + { url = "https://files.pythonhosted.org/packages/b9/3f/3bdd14e737d145114a5eb83cb172903afba7242f67c5877f9909a20d948d/propcache-0.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0a2f2235ac46a7aa25bdeb03a9e7060f6ecbd213b1f9101c43b3090ffb971ef6", size = 205596, upload-time = "2025-06-09T22:55:05.942Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ca/2f4aa819c357d3107c3763d7ef42c03980f9ed5c48c82e01e25945d437c1/propcache-0.3.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:92b69e12e34869a6970fd2f3da91669899994b47c98f5d430b781c26f1d9f387", size = 200977, upload-time = "2025-06-09T22:55:07.792Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4a/e65276c7477533c59085251ae88505caf6831c0e85ff8b2e31ebcbb949b1/propcache-0.3.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:54e02207c79968ebbdffc169591009f4474dde3b4679e16634d34c9363ff56b4", size = 197220, upload-time = "2025-06-09T22:55:09.173Z" }, + { url = "https://files.pythonhosted.org/packages/7c/54/fc7152e517cf5578278b242396ce4d4b36795423988ef39bb8cd5bf274c8/propcache-0.3.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4adfb44cb588001f68c5466579d3f1157ca07f7504fc91ec87862e2b8e556b88", size = 210642, upload-time = "2025-06-09T22:55:10.62Z" }, + { url = "https://files.pythonhosted.org/packages/b9/80/abeb4a896d2767bf5f1ea7b92eb7be6a5330645bd7fb844049c0e4045d9d/propcache-0.3.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:fd3e6019dc1261cd0291ee8919dd91fbab7b169bb76aeef6c716833a3f65d206", size = 212789, upload-time = 
"2025-06-09T22:55:12.029Z" }, + { url = "https://files.pythonhosted.org/packages/b3/db/ea12a49aa7b2b6d68a5da8293dcf50068d48d088100ac016ad92a6a780e6/propcache-0.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4c181cad81158d71c41a2bce88edce078458e2dd5ffee7eddd6b05da85079f43", size = 205880, upload-time = "2025-06-09T22:55:13.45Z" }, + { url = "https://files.pythonhosted.org/packages/d1/e5/9076a0bbbfb65d1198007059c65639dfd56266cf8e477a9707e4b1999ff4/propcache-0.3.2-cp313-cp313-win32.whl", hash = "sha256:8a08154613f2249519e549de2330cf8e2071c2887309a7b07fb56098f5170a02", size = 37220, upload-time = "2025-06-09T22:55:15.284Z" }, + { url = "https://files.pythonhosted.org/packages/d3/f5/b369e026b09a26cd77aa88d8fffd69141d2ae00a2abaaf5380d2603f4b7f/propcache-0.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:e41671f1594fc4ab0a6dec1351864713cb3a279910ae8b58f884a88a0a632c05", size = 40678, upload-time = "2025-06-09T22:55:16.445Z" }, + { url = "https://files.pythonhosted.org/packages/a4/3a/6ece377b55544941a08d03581c7bc400a3c8cd3c2865900a68d5de79e21f/propcache-0.3.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:9a3cf035bbaf035f109987d9d55dc90e4b0e36e04bbbb95af3055ef17194057b", size = 76560, upload-time = "2025-06-09T22:55:17.598Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/64a2bb16418740fa634b0e9c3d29edff1db07f56d3546ca2d86ddf0305e1/propcache-0.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:156c03d07dc1323d8dacaa221fbe028c5c70d16709cdd63502778e6c3ccca1b0", size = 44676, upload-time = "2025-06-09T22:55:18.922Z" }, + { url = "https://files.pythonhosted.org/packages/36/7b/f025e06ea51cb72c52fb87e9b395cced02786610b60a3ed51da8af017170/propcache-0.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74413c0ba02ba86f55cf60d18daab219f7e531620c15f1e23d95563f505efe7e", size = 44701, upload-time = "2025-06-09T22:55:20.106Z" }, + { url = "https://files.pythonhosted.org/packages/a4/00/faa1b1b7c3b74fc277f8642f32a4c72ba1d7b2de36d7cdfb676db7f4303e/propcache-0.3.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f066b437bb3fa39c58ff97ab2ca351db465157d68ed0440abecb21715eb24b28", size = 276934, upload-time = "2025-06-09T22:55:21.5Z" }, + { url = "https://files.pythonhosted.org/packages/74/ab/935beb6f1756e0476a4d5938ff44bf0d13a055fed880caf93859b4f1baf4/propcache-0.3.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1304b085c83067914721e7e9d9917d41ad87696bf70f0bc7dee450e9c71ad0a", size = 278316, upload-time = "2025-06-09T22:55:22.918Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9d/994a5c1ce4389610838d1caec74bdf0e98b306c70314d46dbe4fcf21a3e2/propcache-0.3.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ab50cef01b372763a13333b4e54021bdcb291fc9a8e2ccb9c2df98be51bcde6c", size = 282619, upload-time = "2025-06-09T22:55:24.651Z" }, + { url = "https://files.pythonhosted.org/packages/2b/00/a10afce3d1ed0287cef2e09506d3be9822513f2c1e96457ee369adb9a6cd/propcache-0.3.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fad3b2a085ec259ad2c2842666b2a0a49dea8463579c606426128925af1ed725", size = 265896, upload-time = "2025-06-09T22:55:26.049Z" }, + { url = "https://files.pythonhosted.org/packages/2e/a8/2aa6716ffa566ca57c749edb909ad27884680887d68517e4be41b02299f3/propcache-0.3.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:261fa020c1c14deafd54c76b014956e2f86991af198c51139faf41c4d5e83892", size = 
252111, upload-time = "2025-06-09T22:55:27.381Z" }, + { url = "https://files.pythonhosted.org/packages/36/4f/345ca9183b85ac29c8694b0941f7484bf419c7f0fea2d1e386b4f7893eed/propcache-0.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:46d7f8aa79c927e5f987ee3a80205c987717d3659f035c85cf0c3680526bdb44", size = 268334, upload-time = "2025-06-09T22:55:28.747Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ca/fcd54f78b59e3f97b3b9715501e3147f5340167733d27db423aa321e7148/propcache-0.3.2-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:6d8f3f0eebf73e3c0ff0e7853f68be638b4043c65a70517bb575eff54edd8dbe", size = 255026, upload-time = "2025-06-09T22:55:30.184Z" }, + { url = "https://files.pythonhosted.org/packages/8b/95/8e6a6bbbd78ac89c30c225210a5c687790e532ba4088afb8c0445b77ef37/propcache-0.3.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:03c89c1b14a5452cf15403e291c0ccd7751d5b9736ecb2c5bab977ad6c5bcd81", size = 250724, upload-time = "2025-06-09T22:55:31.646Z" }, + { url = "https://files.pythonhosted.org/packages/ee/b0/0dd03616142baba28e8b2d14ce5df6631b4673850a3d4f9c0f9dd714a404/propcache-0.3.2-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:0cc17efde71e12bbaad086d679ce575268d70bc123a5a71ea7ad76f70ba30bba", size = 268868, upload-time = "2025-06-09T22:55:33.209Z" }, + { url = "https://files.pythonhosted.org/packages/c5/98/2c12407a7e4fbacd94ddd32f3b1e3d5231e77c30ef7162b12a60e2dd5ce3/propcache-0.3.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:acdf05d00696bc0447e278bb53cb04ca72354e562cf88ea6f9107df8e7fd9770", size = 271322, upload-time = "2025-06-09T22:55:35.065Z" }, + { url = "https://files.pythonhosted.org/packages/35/91/9cb56efbb428b006bb85db28591e40b7736847b8331d43fe335acf95f6c8/propcache-0.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4445542398bd0b5d32df908031cb1b30d43ac848e20470a878b770ec2dcc6330", size = 265778, upload-time = "2025-06-09T22:55:36.45Z" }, + { url = "https://files.pythonhosted.org/packages/9a/4c/b0fe775a2bdd01e176b14b574be679d84fc83958335790f7c9a686c1f468/propcache-0.3.2-cp313-cp313t-win32.whl", hash = "sha256:f86e5d7cd03afb3a1db8e9f9f6eff15794e79e791350ac48a8c924e6f439f394", size = 41175, upload-time = "2025-06-09T22:55:38.436Z" }, + { url = "https://files.pythonhosted.org/packages/a4/ff/47f08595e3d9b5e149c150f88d9714574f1a7cbd89fe2817158a952674bf/propcache-0.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9704bedf6e7cbe3c65eca4379a9b53ee6a83749f047808cbb5044d40d7d72198", size = 44857, upload-time = "2025-06-09T22:55:39.687Z" }, + { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, +] + +[[package]] +name = "protobuf" +version = "6.32.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/df/fb4a8eeea482eca989b51cffd274aac2ee24e825f0bf3cbce5281fa1567b/protobuf-6.32.0.tar.gz", hash = "sha256:a81439049127067fc49ec1d36e25c6ee1d1a2b7be930675f919258d03c04e7d2", size = 440614, upload-time = "2025-08-14T21:21:25.015Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/18/df8c87da2e47f4f1dcc5153a81cd6bca4e429803f4069a299e236e4dd510/protobuf-6.32.0-cp310-abi3-win32.whl", hash = "sha256:84f9e3c1ff6fb0308dbacb0950d8aa90694b0d0ee68e75719cb044b7078fe741", size = 424409, upload-time = "2025-08-14T21:21:12.366Z" }, + { url = 
"https://files.pythonhosted.org/packages/e1/59/0a820b7310f8139bd8d5a9388e6a38e1786d179d6f33998448609296c229/protobuf-6.32.0-cp310-abi3-win_amd64.whl", hash = "sha256:a8bdbb2f009cfc22a36d031f22a625a38b615b5e19e558a7b756b3279723e68e", size = 435735, upload-time = "2025-08-14T21:21:15.046Z" }, + { url = "https://files.pythonhosted.org/packages/cc/5b/0d421533c59c789e9c9894683efac582c06246bf24bb26b753b149bd88e4/protobuf-6.32.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d52691e5bee6c860fff9a1c86ad26a13afbeb4b168cd4445c922b7e2cf85aaf0", size = 426449, upload-time = "2025-08-14T21:21:16.687Z" }, + { url = "https://files.pythonhosted.org/packages/ec/7b/607764ebe6c7a23dcee06e054fd1de3d5841b7648a90fd6def9a3bb58c5e/protobuf-6.32.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:501fe6372fd1c8ea2a30b4d9be8f87955a64d6be9c88a973996cef5ef6f0abf1", size = 322869, upload-time = "2025-08-14T21:21:18.282Z" }, + { url = "https://files.pythonhosted.org/packages/40/01/2e730bd1c25392fc32e3268e02446f0d77cb51a2c3a8486b1798e34d5805/protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:75a2aab2bd1aeb1f5dc7c5f33bcb11d82ea8c055c9becbb41c26a8c43fd7092c", size = 322009, upload-time = "2025-08-14T21:21:19.893Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, +] + +[[package]] +name = "psutil" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/31/4723d756b59344b643542936e37a31d1d3204bcdc42a7daa8ee9eb06fb50/psutil-7.1.0.tar.gz", hash = "sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2", size = 497660, upload-time = "2025-09-17T20:14:52.902Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/62/ce4051019ee20ce0ed74432dd73a5bb087a6704284a470bb8adff69a0932/psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13", size = 245242, upload-time = "2025-09-17T20:14:56.126Z" }, + { url = "https://files.pythonhosted.org/packages/38/61/f76959fba841bf5b61123fbf4b650886dc4094c6858008b5bf73d9057216/psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5", size = 246682, upload-time = "2025-09-17T20:14:58.25Z" }, + { url = "https://files.pythonhosted.org/packages/88/7a/37c99d2e77ec30d63398ffa6a660450b8a62517cabe44b3e9bae97696e8d/psutil-7.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e4454970b32472ce7deaa45d045b34d3648ce478e26a04c7e858a0a6e75ff3", size = 287994, upload-time = "2025-09-17T20:14:59.901Z" }, + { url = "https://files.pythonhosted.org/packages/9d/de/04c8c61232f7244aa0a4b9a9fbd63a89d5aeaf94b2fc9d1d16e2faa5cbb0/psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70e113920d51e89f212dd7be06219a9b88014e63a4cec69b684c327bc474e3", size = 291163, upload-time = "2025-09-17T20:15:01.481Z" }, + { url = "https://files.pythonhosted.org/packages/f4/58/c4f976234bf6d4737bc8c02a81192f045c307b72cf39c9e5c5a2d78927f6/psutil-7.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d4a113425c037300de3ac8b331637293da9be9713855c4fc9d2d97436d7259d", 
size = 293625, upload-time = "2025-09-17T20:15:04.492Z" }, + { url = "https://files.pythonhosted.org/packages/79/87/157c8e7959ec39ced1b11cc93c730c4fb7f9d408569a6c59dbd92ceb35db/psutil-7.1.0-cp37-abi3-win32.whl", hash = "sha256:09ad740870c8d219ed8daae0ad3b726d3bf9a028a198e7f3080f6a1888b99bca", size = 244812, upload-time = "2025-09-17T20:15:07.462Z" }, + { url = "https://files.pythonhosted.org/packages/bf/e9/b44c4f697276a7a95b8e94d0e320a7bf7f3318521b23de69035540b39838/psutil-7.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:57f5e987c36d3146c0dd2528cd42151cf96cd359b9d67cfff836995cc5df9a3d", size = 247965, upload-time = "2025-09-17T20:15:09.673Z" }, + { url = "https://files.pythonhosted.org/packages/26/65/1070a6e3c036f39142c2820c4b52e9243246fcfc3f96239ac84472ba361e/psutil-7.1.0-cp37-abi3-win_arm64.whl", hash = "sha256:6937cb68133e7c97b6cc9649a570c9a18ba0efebed46d8c5dae4c07fa1b67a07", size = 244971, upload-time = "2025-09-17T20:15:12.262Z" }, +] + +[[package]] +name = "pyarrow" +version = "21.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/c2/ea068b8f00905c06329a3dfcd40d0fcc2b7d0f2e355bdb25b65e0a0e4cd4/pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc", size = 1133487, upload-time = "2025-07-18T00:57:31.761Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/d9/110de31880016e2afc52d8580b397dbe47615defbf09ca8cf55f56c62165/pyarrow-21.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e563271e2c5ff4d4a4cbeb2c83d5cf0d4938b891518e676025f7268c6fe5fe26", size = 31196837, upload-time = "2025-07-18T00:54:34.755Z" }, + { url = "https://files.pythonhosted.org/packages/df/5f/c1c1997613abf24fceb087e79432d24c19bc6f7259cab57c2c8e5e545fab/pyarrow-21.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fee33b0ca46f4c85443d6c450357101e47d53e6c3f008d658c27a2d020d44c79", size = 32659470, upload-time = "2025-07-18T00:54:38.329Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ed/b1589a777816ee33ba123ba1e4f8f02243a844fed0deec97bde9fb21a5cf/pyarrow-21.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:7be45519b830f7c24b21d630a31d48bcebfd5d4d7f9d3bdb49da9cdf6d764edb", size = 41055619, upload-time = "2025-07-18T00:54:42.172Z" }, + { url = "https://files.pythonhosted.org/packages/44/28/b6672962639e85dc0ac36f71ab3a8f5f38e01b51343d7aa372a6b56fa3f3/pyarrow-21.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:26bfd95f6bff443ceae63c65dc7e048670b7e98bc892210acba7e4995d3d4b51", size = 42733488, upload-time = "2025-07-18T00:54:47.132Z" }, + { url = "https://files.pythonhosted.org/packages/f8/cc/de02c3614874b9089c94eac093f90ca5dfa6d5afe45de3ba847fd950fdf1/pyarrow-21.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bd04ec08f7f8bd113c55868bd3fc442a9db67c27af098c5f814a3091e71cc61a", size = 43329159, upload-time = "2025-07-18T00:54:51.686Z" }, + { url = "https://files.pythonhosted.org/packages/a6/3e/99473332ac40278f196e105ce30b79ab8affab12f6194802f2593d6b0be2/pyarrow-21.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9b0b14b49ac10654332a805aedfc0147fb3469cbf8ea951b3d040dab12372594", size = 45050567, upload-time = "2025-07-18T00:54:56.679Z" }, + { url = "https://files.pythonhosted.org/packages/7b/f5/c372ef60593d713e8bfbb7e0c743501605f0ad00719146dc075faf11172b/pyarrow-21.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:9d9f8bcb4c3be7738add259738abdeddc363de1b80e3310e04067aa1ca596634", size = 26217959, upload-time = "2025-07-18T00:55:00.482Z" }, + { 
url = "https://files.pythonhosted.org/packages/94/dc/80564a3071a57c20b7c32575e4a0120e8a330ef487c319b122942d665960/pyarrow-21.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c077f48aab61738c237802836fc3844f85409a46015635198761b0d6a688f87b", size = 31243234, upload-time = "2025-07-18T00:55:03.812Z" }, + { url = "https://files.pythonhosted.org/packages/ea/cc/3b51cb2db26fe535d14f74cab4c79b191ed9a8cd4cbba45e2379b5ca2746/pyarrow-21.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:689f448066781856237eca8d1975b98cace19b8dd2ab6145bf49475478bcaa10", size = 32714370, upload-time = "2025-07-18T00:55:07.495Z" }, + { url = "https://files.pythonhosted.org/packages/24/11/a4431f36d5ad7d83b87146f515c063e4d07ef0b7240876ddb885e6b44f2e/pyarrow-21.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:479ee41399fcddc46159a551705b89c05f11e8b8cb8e968f7fec64f62d91985e", size = 41135424, upload-time = "2025-07-18T00:55:11.461Z" }, + { url = "https://files.pythonhosted.org/packages/74/dc/035d54638fc5d2971cbf1e987ccd45f1091c83bcf747281cf6cc25e72c88/pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:40ebfcb54a4f11bcde86bc586cbd0272bac0d516cfa539c799c2453768477569", size = 42823810, upload-time = "2025-07-18T00:55:16.301Z" }, + { url = "https://files.pythonhosted.org/packages/2e/3b/89fced102448a9e3e0d4dded1f37fa3ce4700f02cdb8665457fcc8015f5b/pyarrow-21.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8d58d8497814274d3d20214fbb24abcad2f7e351474357d552a8d53bce70c70e", size = 43391538, upload-time = "2025-07-18T00:55:23.82Z" }, + { url = "https://files.pythonhosted.org/packages/fb/bb/ea7f1bd08978d39debd3b23611c293f64a642557e8141c80635d501e6d53/pyarrow-21.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:585e7224f21124dd57836b1530ac8f2df2afc43c861d7bf3d58a4870c42ae36c", size = 45120056, upload-time = "2025-07-18T00:55:28.231Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0b/77ea0600009842b30ceebc3337639a7380cd946061b620ac1a2f3cb541e2/pyarrow-21.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:555ca6935b2cbca2c0e932bedd853e9bc523098c39636de9ad4693b5b1df86d6", size = 26220568, upload-time = "2025-07-18T00:55:32.122Z" }, + { url = "https://files.pythonhosted.org/packages/ca/d4/d4f817b21aacc30195cf6a46ba041dd1be827efa4a623cc8bf39a1c2a0c0/pyarrow-21.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3a302f0e0963db37e0a24a70c56cf91a4faa0bca51c23812279ca2e23481fccd", size = 31160305, upload-time = "2025-07-18T00:55:35.373Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9c/dcd38ce6e4b4d9a19e1d36914cb8e2b1da4e6003dd075474c4cfcdfe0601/pyarrow-21.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:b6b27cf01e243871390474a211a7922bfbe3bda21e39bc9160daf0da3fe48876", size = 32684264, upload-time = "2025-07-18T00:55:39.303Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/2a2d9f8d7a59b639523454bec12dba35ae3d0a07d8ab529dc0809f74b23c/pyarrow-21.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e72a8ec6b868e258a2cd2672d91f2860ad532d590ce94cdf7d5e7ec674ccf03d", size = 41108099, upload-time = "2025-07-18T00:55:42.889Z" }, + { url = "https://files.pythonhosted.org/packages/ad/90/2660332eeb31303c13b653ea566a9918484b6e4d6b9d2d46879a33ab0622/pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b7ae0bbdc8c6674259b25bef5d2a1d6af5d39d7200c819cf99e07f7dfef1c51e", size = 42829529, upload-time = "2025-07-18T00:55:47.069Z" }, + { url = 
"https://files.pythonhosted.org/packages/33/27/1a93a25c92717f6aa0fca06eb4700860577d016cd3ae51aad0e0488ac899/pyarrow-21.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:58c30a1729f82d201627c173d91bd431db88ea74dcaa3885855bc6203e433b82", size = 43367883, upload-time = "2025-07-18T00:55:53.069Z" }, + { url = "https://files.pythonhosted.org/packages/05/d9/4d09d919f35d599bc05c6950095e358c3e15148ead26292dfca1fb659b0c/pyarrow-21.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:072116f65604b822a7f22945a7a6e581cfa28e3454fdcc6939d4ff6090126623", size = 45133802, upload-time = "2025-07-18T00:55:57.714Z" }, + { url = "https://files.pythonhosted.org/packages/71/30/f3795b6e192c3ab881325ffe172e526499eb3780e306a15103a2764916a2/pyarrow-21.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:cf56ec8b0a5c8c9d7021d6fd754e688104f9ebebf1bf4449613c9531f5346a18", size = 26203175, upload-time = "2025-07-18T00:56:01.364Z" }, + { url = "https://files.pythonhosted.org/packages/16/ca/c7eaa8e62db8fb37ce942b1ea0c6d7abfe3786ca193957afa25e71b81b66/pyarrow-21.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e99310a4ebd4479bcd1964dff9e14af33746300cb014aa4a3781738ac63baf4a", size = 31154306, upload-time = "2025-07-18T00:56:04.42Z" }, + { url = "https://files.pythonhosted.org/packages/ce/e8/e87d9e3b2489302b3a1aea709aaca4b781c5252fcb812a17ab6275a9a484/pyarrow-21.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d2fe8e7f3ce329a71b7ddd7498b3cfac0eeb200c2789bd840234f0dc271a8efe", size = 32680622, upload-time = "2025-07-18T00:56:07.505Z" }, + { url = "https://files.pythonhosted.org/packages/84/52/79095d73a742aa0aba370c7942b1b655f598069489ab387fe47261a849e1/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f522e5709379d72fb3da7785aa489ff0bb87448a9dc5a75f45763a795a089ebd", size = 41104094, upload-time = "2025-07-18T00:56:10.994Z" }, + { url = "https://files.pythonhosted.org/packages/89/4b/7782438b551dbb0468892a276b8c789b8bbdb25ea5c5eb27faadd753e037/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:69cbbdf0631396e9925e048cfa5bce4e8c3d3b41562bbd70c685a8eb53a91e61", size = 42825576, upload-time = "2025-07-18T00:56:15.569Z" }, + { url = "https://files.pythonhosted.org/packages/b3/62/0f29de6e0a1e33518dec92c65be0351d32d7ca351e51ec5f4f837a9aab91/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:731c7022587006b755d0bdb27626a1a3bb004bb56b11fb30d98b6c1b4718579d", size = 43368342, upload-time = "2025-07-18T00:56:19.531Z" }, + { url = "https://files.pythonhosted.org/packages/90/c7/0fa1f3f29cf75f339768cc698c8ad4ddd2481c1742e9741459911c9ac477/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc56bc708f2d8ac71bd1dcb927e458c93cec10b98eb4120206a4091db7b67b99", size = 45131218, upload-time = "2025-07-18T00:56:23.347Z" }, + { url = "https://files.pythonhosted.org/packages/01/63/581f2076465e67b23bc5a37d4a2abff8362d389d29d8105832e82c9c811c/pyarrow-21.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:186aa00bca62139f75b7de8420f745f2af12941595bbbfa7ed3870ff63e25636", size = 26087551, upload-time = "2025-07-18T00:56:26.758Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ab/357d0d9648bb8241ee7348e564f2479d206ebe6e1c47ac5027c2e31ecd39/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:a7a102574faa3f421141a64c10216e078df467ab9576684d5cd696952546e2da", size = 31290064, upload-time = "2025-07-18T00:56:30.214Z" }, + { url = 
"https://files.pythonhosted.org/packages/3f/8a/5685d62a990e4cac2043fc76b4661bf38d06efed55cf45a334b455bd2759/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:1e005378c4a2c6db3ada3ad4c217b381f6c886f0a80d6a316fe586b90f77efd7", size = 32727837, upload-time = "2025-07-18T00:56:33.935Z" }, + { url = "https://files.pythonhosted.org/packages/fc/de/c0828ee09525c2bafefd3e736a248ebe764d07d0fd762d4f0929dbc516c9/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65f8e85f79031449ec8706b74504a316805217b35b6099155dd7e227eef0d4b6", size = 41014158, upload-time = "2025-07-18T00:56:37.528Z" }, + { url = "https://files.pythonhosted.org/packages/6e/26/a2865c420c50b7a3748320b614f3484bfcde8347b2639b2b903b21ce6a72/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3a81486adc665c7eb1a2bde0224cfca6ceaba344a82a971ef059678417880eb8", size = 42667885, upload-time = "2025-07-18T00:56:41.483Z" }, + { url = "https://files.pythonhosted.org/packages/0a/f9/4ee798dc902533159250fb4321267730bc0a107d8c6889e07c3add4fe3a5/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:fc0d2f88b81dcf3ccf9a6ae17f89183762c8a94a5bdcfa09e05cfe413acf0503", size = 43276625, upload-time = "2025-07-18T00:56:48.002Z" }, + { url = "https://files.pythonhosted.org/packages/5a/da/e02544d6997037a4b0d22d8e5f66bc9315c3671371a8b18c79ade1cefe14/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6299449adf89df38537837487a4f8d3bd91ec94354fdd2a7d30bc11c48ef6e79", size = 44951890, upload-time = "2025-07-18T00:56:52.568Z" }, + { url = "https://files.pythonhosted.org/packages/e5/4e/519c1bc1876625fe6b71e9a28287c43ec2f20f73c658b9ae1d485c0c206e/pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10", size = 26371006, upload-time = "2025-07-18T00:56:56.379Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817, upload-time = "2025-04-23T18:30:43.919Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357, upload-time = "2025-04-23T18:30:46.372Z" }, + { url = "https://files.pythonhosted.org/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011, upload-time = "2025-04-23T18:30:47.591Z" }, + { url = "https://files.pythonhosted.org/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730, upload-time = "2025-04-23T18:30:49.328Z" }, + { url = "https://files.pythonhosted.org/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178, upload-time = "2025-04-23T18:30:50.907Z" }, + { url = "https://files.pythonhosted.org/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462, upload-time = "2025-04-23T18:30:52.083Z" }, + { url = "https://files.pythonhosted.org/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652, upload-time = "2025-04-23T18:30:53.389Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306, upload-time = "2025-04-23T18:30:54.661Z" }, + { url = "https://files.pythonhosted.org/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720, upload-time = "2025-04-23T18:30:56.11Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915, upload-time = "2025-04-23T18:30:57.501Z" }, + { url = "https://files.pythonhosted.org/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884, upload-time = "2025-04-23T18:30:58.867Z" }, + { url = "https://files.pythonhosted.org/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = 
"sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496, upload-time = "2025-04-23T18:31:00.078Z" }, + { url = "https://files.pythonhosted.org/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019, upload-time = "2025-04-23T18:31:01.335Z" }, + { url = "https://files.pythonhosted.org/packages/3f/8d/71db63483d518cbbf290261a1fc2839d17ff89fce7089e08cad07ccfce67/pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7", size = 2028584, upload-time = "2025-04-23T18:31:03.106Z" }, + { url = "https://files.pythonhosted.org/packages/24/2f/3cfa7244ae292dd850989f328722d2aef313f74ffc471184dc509e1e4e5a/pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246", size = 1855071, upload-time = "2025-04-23T18:31:04.621Z" }, + { url = "https://files.pythonhosted.org/packages/b3/d3/4ae42d33f5e3f50dd467761304be2fa0a9417fbf09735bc2cce003480f2a/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f", size = 1897823, upload-time = "2025-04-23T18:31:06.377Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f3/aa5976e8352b7695ff808599794b1fba2a9ae2ee954a3426855935799488/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc", size = 1983792, upload-time = "2025-04-23T18:31:07.93Z" }, + { url = "https://files.pythonhosted.org/packages/d5/7a/cda9b5a23c552037717f2b2a5257e9b2bfe45e687386df9591eff7b46d28/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de", size = 2136338, upload-time = "2025-04-23T18:31:09.283Z" }, + { url = "https://files.pythonhosted.org/packages/2b/9f/b8f9ec8dd1417eb9da784e91e1667d58a2a4a7b7b34cf4af765ef663a7e5/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a", size = 2730998, upload-time = "2025-04-23T18:31:11.7Z" }, + { url = "https://files.pythonhosted.org/packages/47/bc/cd720e078576bdb8255d5032c5d63ee5c0bf4b7173dd955185a1d658c456/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef", size = 2003200, upload-time = "2025-04-23T18:31:13.536Z" }, + { url = "https://files.pythonhosted.org/packages/ca/22/3602b895ee2cd29d11a2b349372446ae9727c32e78a94b3d588a40fdf187/pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e", size = 2113890, upload-time = "2025-04-23T18:31:15.011Z" }, + { url = "https://files.pythonhosted.org/packages/ff/e6/e3c5908c03cf00d629eb38393a98fccc38ee0ce8ecce32f69fc7d7b558a7/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d", size = 2073359, upload-time = "2025-04-23T18:31:16.393Z" }, + { url = 
"https://files.pythonhosted.org/packages/12/e7/6a36a07c59ebefc8777d1ffdaf5ae71b06b21952582e4b07eba88a421c79/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30", size = 2245883, upload-time = "2025-04-23T18:31:17.892Z" }, + { url = "https://files.pythonhosted.org/packages/16/3f/59b3187aaa6cc0c1e6616e8045b284de2b6a87b027cce2ffcea073adf1d2/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf", size = 2241074, upload-time = "2025-04-23T18:31:19.205Z" }, + { url = "https://files.pythonhosted.org/packages/e0/ed/55532bb88f674d5d8f67ab121a2a13c385df382de2a1677f30ad385f7438/pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51", size = 1910538, upload-time = "2025-04-23T18:31:20.541Z" }, + { url = "https://files.pythonhosted.org/packages/fe/1b/25b7cccd4519c0b23c2dd636ad39d381abf113085ce4f7bec2b0dc755eb1/pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab", size = 1952909, upload-time = "2025-04-23T18:31:22.371Z" }, + { url = "https://files.pythonhosted.org/packages/49/a9/d809358e49126438055884c4366a1f6227f0f84f635a9014e2deb9b9de54/pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65", size = 1897786, upload-time = "2025-04-23T18:31:24.161Z" }, + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = 
"2025-04-23T18:31:33.958Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = 
"2025-04-23T18:31:57.393Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = "2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", 
size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, + { url = "https://files.pythonhosted.org/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982, upload-time = "2025-04-23T18:32:53.14Z" }, + { url = "https://files.pythonhosted.org/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412, upload-time = "2025-04-23T18:32:55.52Z" }, + { url = "https://files.pythonhosted.org/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749, upload-time = "2025-04-23T18:32:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527, upload-time = "2025-04-23T18:32:59.771Z" }, + { url = "https://files.pythonhosted.org/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225, upload-time = "2025-04-23T18:33:04.51Z" }, + { url = "https://files.pythonhosted.org/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490, upload-time = "2025-04-23T18:33:06.391Z" }, + { url = "https://files.pythonhosted.org/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525, upload-time = "2025-04-23T18:33:08.44Z" }, + { url = 
"https://files.pythonhosted.org/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446, upload-time = "2025-04-23T18:33:10.313Z" }, + { url = "https://files.pythonhosted.org/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678, upload-time = "2025-04-23T18:33:12.224Z" }, + { url = "https://files.pythonhosted.org/packages/7b/27/d4ae6487d73948d6f20dddcd94be4ea43e74349b56eba82e9bdee2d7494c/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8", size = 2025200, upload-time = "2025-04-23T18:33:14.199Z" }, + { url = "https://files.pythonhosted.org/packages/f1/b8/b3cb95375f05d33801024079b9392a5ab45267a63400bf1866e7ce0f0de4/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593", size = 1859123, upload-time = "2025-04-23T18:33:16.555Z" }, + { url = "https://files.pythonhosted.org/packages/05/bc/0d0b5adeda59a261cd30a1235a445bf55c7e46ae44aea28f7bd6ed46e091/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612", size = 1892852, upload-time = "2025-04-23T18:33:18.513Z" }, + { url = "https://files.pythonhosted.org/packages/3e/11/d37bdebbda2e449cb3f519f6ce950927b56d62f0b84fd9cb9e372a26a3d5/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7", size = 2067484, upload-time = "2025-04-23T18:33:20.475Z" }, + { url = "https://files.pythonhosted.org/packages/8c/55/1f95f0a05ce72ecb02a8a8a1c3be0579bbc29b1d5ab68f1378b7bebc5057/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e", size = 2108896, upload-time = "2025-04-23T18:33:22.501Z" }, + { url = "https://files.pythonhosted.org/packages/53/89/2b2de6c81fa131f423246a9109d7b2a375e83968ad0800d6e57d0574629b/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8", size = 2069475, upload-time = "2025-04-23T18:33:24.528Z" }, + { url = "https://files.pythonhosted.org/packages/b8/e9/1f7efbe20d0b2b10f6718944b5d8ece9152390904f29a78e68d4e7961159/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf", size = 2239013, upload-time = "2025-04-23T18:33:26.621Z" }, + { url = "https://files.pythonhosted.org/packages/3c/b2/5309c905a93811524a49b4e031e9851a6b00ff0fb668794472ea7746b448/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb", size = 2238715, upload-time = "2025-04-23T18:33:28.656Z" }, + { url = "https://files.pythonhosted.org/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = 
"sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757, upload-time = "2025-04-23T18:33:30.645Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631, upload-time = "2024-08-06T20:33:50.674Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/95/a3fac87cb7158e231b5a6012e438c647e1a87f09f8e0d123acec8ab8bf71/PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", size = 184199, upload-time = "2024-08-06T20:31:40.178Z" }, + { url = "https://files.pythonhosted.org/packages/c7/7a/68bd47624dab8fd4afbfd3c48e3b79efe09098ae941de5b58abcbadff5cb/PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", size = 171758, upload-time = "2024-08-06T20:31:42.173Z" }, + { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463, upload-time = "2024-08-06T20:31:44.263Z" }, + { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280, upload-time = "2024-08-06T20:31:50.199Z" }, + { url = "https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239, upload-time = "2024-08-06T20:31:52.292Z" }, + { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802, upload-time = "2024-08-06T20:31:53.836Z" }, + { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527, upload-time = "2024-08-06T20:31:55.565Z" }, + { url = "https://files.pythonhosted.org/packages/be/aa/5afe99233fb360d0ff37377145a949ae258aaab831bde4792b32650a4378/PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", size = 144052, upload-time = "2024-08-06T20:31:56.914Z" }, + { url = "https://files.pythonhosted.org/packages/b5/84/0fa4b06f6d6c958d207620fc60005e241ecedceee58931bb20138e1e5776/PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", size = 161774, upload-time = "2024-08-06T20:31:58.304Z" }, + { url = "https://files.pythonhosted.org/packages/f8/aa/7af4e81f7acba21a4c6be026da38fd2b872ca46226673c89a758ebdc4fd2/PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", size = 184612, upload-time = "2024-08-06T20:32:03.408Z" }, + { url = 
"https://files.pythonhosted.org/packages/8b/62/b9faa998fd185f65c1371643678e4d58254add437edb764a08c5a98fb986/PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", size = 172040, upload-time = "2024-08-06T20:32:04.926Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829, upload-time = "2024-08-06T20:32:06.459Z" }, + { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167, upload-time = "2024-08-06T20:32:08.338Z" }, + { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952, upload-time = "2024-08-06T20:32:14.124Z" }, + { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301, upload-time = "2024-08-06T20:32:16.17Z" }, + { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638, upload-time = "2024-08-06T20:32:18.555Z" }, + { url = "https://files.pythonhosted.org/packages/22/5f/956f0f9fc65223a58fbc14459bf34b4cc48dec52e00535c79b8db361aabd/PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", size = 143850, upload-time = "2024-08-06T20:32:19.889Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/8da0bbe2ab9dcdd11f4f4557ccaf95c10b9811b13ecced089d43ce59c3c8/PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", size = 161980, upload-time = "2024-08-06T20:32:21.273Z" }, + { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873, upload-time = "2024-08-06T20:32:25.131Z" }, + { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302, upload-time = "2024-08-06T20:32:26.511Z" }, + { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154, upload-time = "2024-08-06T20:32:28.363Z" }, + { url = 
"https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223, upload-time = "2024-08-06T20:32:30.058Z" }, + { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542, upload-time = "2024-08-06T20:32:31.881Z" }, + { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164, upload-time = "2024-08-06T20:32:37.083Z" }, + { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611, upload-time = "2024-08-06T20:32:38.898Z" }, + { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591, upload-time = "2024-08-06T20:32:40.241Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, + { url = "https://files.pythonhosted.org/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309, upload-time = "2024-08-06T20:32:43.4Z" }, + { url = "https://files.pythonhosted.org/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679, upload-time = "2024-08-06T20:32:44.801Z" }, + { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428, upload-time = "2024-08-06T20:32:46.432Z" }, + { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361, upload-time = "2024-08-06T20:32:51.188Z" }, + { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523, upload-time = "2024-08-06T20:32:53.019Z" }, + { url = 
"https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660, upload-time = "2024-08-06T20:32:54.708Z" }, + { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597, upload-time = "2024-08-06T20:32:56.985Z" }, + { url = "https://files.pythonhosted.org/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527, upload-time = "2024-08-06T20:33:03.001Z" }, + { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, +] + +[[package]] +name = "regex" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/5a/4c63457fbcaf19d138d72b2e9b39405954f98c0349b31c601bfcb151582c/regex-2025.9.1.tar.gz", hash = "sha256:88ac07b38d20b54d79e704e38aa3bd2c0f8027432164226bdee201a1c0c9c9ff", size = 400852, upload-time = "2025-09-01T22:10:10.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/c1/ed9ef923156105a78aa004f9390e5dd87eadc29f5ca8840f172cadb638de/regex-2025.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c5aa2a6a73bf218515484b36a0d20c6ad9dc63f6339ff6224147b0e2c095ee55", size = 484813, upload-time = "2025-09-01T22:07:45.528Z" }, + { url = "https://files.pythonhosted.org/packages/05/de/97957618a774c67f892609eee2fafe3e30703fbbba66de5e6b79d7196dbc/regex-2025.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c2ff5c01d5e47ad5fc9d31bcd61e78c2fa0068ed00cab86b7320214446da766", size = 288981, upload-time = "2025-09-01T22:07:48.464Z" }, + { url = "https://files.pythonhosted.org/packages/3c/b0/441afadd0a6ffccbd58a9663e5bdd182daa237893e5f8ceec6ff9df4418a/regex-2025.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d49dc84e796b666181de8a9973284cad6616335f01b52bf099643253094920fc", size = 286608, upload-time = "2025-09-01T22:07:50.484Z" }, + { url = "https://files.pythonhosted.org/packages/6e/cf/d89aecaf17e999ab11a3ef73fc9ab8b64f4e156f121250ef84340b35338d/regex-2025.9.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9914fe1040874f83c15fcea86d94ea54091b0666eab330aaab69e30d106aabe", size = 780459, upload-time = "2025-09-01T22:07:52.34Z" }, + { url = "https://files.pythonhosted.org/packages/f6/05/05884594a9975a29597917bbdd6837f7b97e8ac23faf22d628aa781e58f7/regex-2025.9.1-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e71bceb3947362ec5eabd2ca0870bb78eae4edfc60c6c21495133c01b6cd2df4", size = 849276, upload-time = "2025-09-01T22:07:54.591Z" }, + { url = "https://files.pythonhosted.org/packages/8c/8d/2b3067506838d02096bf107beb129b2ce328cdf776d6474b7f542c0a7bfd/regex-2025.9.1-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:67a74456f410fe5e869239ee7a5423510fe5121549af133809d9591a8075893f", size = 897320, upload-time = "2025-09-01T22:07:56.129Z" }, + { url = "https://files.pythonhosted.org/packages/9e/b3/0f9f7766e980b900df0ba9901b52871a2e4203698fb35cdebd219240d5f7/regex-2025.9.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5c3b96ed0223b32dbdc53a83149b6de7ca3acd5acd9c8e64b42a166228abe29c", size = 789931, upload-time = "2025-09-01T22:07:57.834Z" }, + { url = "https://files.pythonhosted.org/packages/47/9f/7b2f29c8f8b698eb44be5fc68e8b9c8d32e99635eac5defc98de114e9f35/regex-2025.9.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:113d5aa950f428faf46fd77d452df62ebb4cc6531cb619f6cc30a369d326bfbd", size = 780764, upload-time = "2025-09-01T22:07:59.413Z" }, + { url = "https://files.pythonhosted.org/packages/ac/ac/56176caa86155c14462531eb0a4ddc450d17ba8875001122b3b7c0cb01bf/regex-2025.9.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fcdeb38de4f7f3d69d798f4f371189061446792a84e7c92b50054c87aae9c07c", size = 773610, upload-time = "2025-09-01T22:08:01.042Z" }, + { url = "https://files.pythonhosted.org/packages/39/e8/9d6b9bd43998268a9de2f35602077519cacc9cb149f7381758cf8f502ba7/regex-2025.9.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4bcdff370509164b67a6c8ec23c9fb40797b72a014766fdc159bb809bd74f7d8", size = 844090, upload-time = "2025-09-01T22:08:02.94Z" }, + { url = "https://files.pythonhosted.org/packages/fd/92/d89743b089005cae4cb81cc2fe177e180b7452e60f29de53af34349640f8/regex-2025.9.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:7383efdf6e8e8c61d85e00cfb2e2e18da1a621b8bfb4b0f1c2747db57b942b8f", size = 834775, upload-time = "2025-09-01T22:08:04.781Z" }, + { url = "https://files.pythonhosted.org/packages/01/8f/86a3e0aaa89295d2a3445bb238e56369963ef6b02a5b4aa3362f4e687413/regex-2025.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1ec2bd3bdf0f73f7e9f48dca550ba7d973692d5e5e9a90ac42cc5f16c4432d8b", size = 778521, upload-time = "2025-09-01T22:08:06.596Z" }, + { url = "https://files.pythonhosted.org/packages/3e/df/72072acb370ee8577c255717f8a58264f1d0de40aa3c9e6ebd5271cac633/regex-2025.9.1-cp310-cp310-win32.whl", hash = "sha256:9627e887116c4e9c0986d5c3b4f52bcfe3df09850b704f62ec3cbf177a0ae374", size = 264105, upload-time = "2025-09-01T22:08:08.708Z" }, + { url = "https://files.pythonhosted.org/packages/97/73/fb82faaf0375aeaa1bb675008246c79b6779fa5688585a35327610ea0e2e/regex-2025.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:94533e32dc0065eca43912ee6649c90ea0681d59f56d43c45b5bcda9a740b3dd", size = 276131, upload-time = "2025-09-01T22:08:10.156Z" }, + { url = "https://files.pythonhosted.org/packages/d3/3a/77d7718a2493e54725494f44da1a1e55704743dc4b8fabe5b0596f7b8014/regex-2025.9.1-cp310-cp310-win_arm64.whl", hash = "sha256:a874a61bb580d48642ffd338570ee24ab13fa023779190513fcacad104a6e251", size = 268462, upload-time = "2025-09-01T22:08:11.651Z" }, + { url = "https://files.pythonhosted.org/packages/06/4d/f741543c0c59f96c6625bc6c11fea1da2e378b7d293ffff6f318edc0ce14/regex-2025.9.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e5bcf112b09bfd3646e4db6bf2e598534a17d502b0c01ea6550ba4eca780c5e6", size = 484811, upload-time = "2025-09-01T22:08:12.834Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bd/27e73e92635b6fbd51afc26a414a3133243c662949cd1cda677fe7bb09bd/regex-2025.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:67a0295a3c31d675a9ee0238d20238ff10a9a2fdb7a1323c798fc7029578b15c", size = 
288977, upload-time = "2025-09-01T22:08:14.499Z" }, + { url = "https://files.pythonhosted.org/packages/eb/7d/7dc0c6efc8bc93cd6e9b947581f5fde8a5dbaa0af7c4ec818c5729fdc807/regex-2025.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea8267fbadc7d4bd7c1301a50e85c2ff0de293ff9452a1a9f8d82c6cafe38179", size = 286606, upload-time = "2025-09-01T22:08:15.881Z" }, + { url = "https://files.pythonhosted.org/packages/d1/01/9b5c6dd394f97c8f2c12f6e8f96879c9ac27292a718903faf2e27a0c09f6/regex-2025.9.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6aeff21de7214d15e928fb5ce757f9495214367ba62875100d4c18d293750cc1", size = 792436, upload-time = "2025-09-01T22:08:17.38Z" }, + { url = "https://files.pythonhosted.org/packages/fc/24/b7430cfc6ee34bbb3db6ff933beb5e7692e5cc81e8f6f4da63d353566fb0/regex-2025.9.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d89f1bbbbbc0885e1c230f7770d5e98f4f00b0ee85688c871d10df8b184a6323", size = 858705, upload-time = "2025-09-01T22:08:19.037Z" }, + { url = "https://files.pythonhosted.org/packages/d6/98/155f914b4ea6ae012663188545c4f5216c11926d09b817127639d618b003/regex-2025.9.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ca3affe8ddea498ba9d294ab05f5f2d3b5ad5d515bc0d4a9016dd592a03afe52", size = 905881, upload-time = "2025-09-01T22:08:20.377Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a7/a470e7bc8259c40429afb6d6a517b40c03f2f3e455c44a01abc483a1c512/regex-2025.9.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:91892a7a9f0a980e4c2c85dd19bc14de2b219a3a8867c4b5664b9f972dcc0c78", size = 798968, upload-time = "2025-09-01T22:08:22.081Z" }, + { url = "https://files.pythonhosted.org/packages/1d/fa/33f6fec4d41449fea5f62fdf5e46d668a1c046730a7f4ed9f478331a8e3a/regex-2025.9.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e1cb40406f4ae862710615f9f636c1e030fd6e6abe0e0f65f6a695a2721440c6", size = 781884, upload-time = "2025-09-01T22:08:23.832Z" }, + { url = "https://files.pythonhosted.org/packages/42/de/2b45f36ab20da14eedddf5009d370625bc5942d9953fa7e5037a32d66843/regex-2025.9.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:94f6cff6f7e2149c7e6499a6ecd4695379eeda8ccbccb9726e8149f2fe382e92", size = 852935, upload-time = "2025-09-01T22:08:25.536Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f9/878f4fc92c87e125e27aed0f8ee0d1eced9b541f404b048f66f79914475a/regex-2025.9.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:6c0226fb322b82709e78c49cc33484206647f8a39954d7e9de1567f5399becd0", size = 844340, upload-time = "2025-09-01T22:08:27.141Z" }, + { url = "https://files.pythonhosted.org/packages/90/c2/5b6f2bce6ece5f8427c718c085eca0de4bbb4db59f54db77aa6557aef3e9/regex-2025.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a12f59c7c380b4fcf7516e9cbb126f95b7a9518902bcf4a852423ff1dcd03e6a", size = 787238, upload-time = "2025-09-01T22:08:28.75Z" }, + { url = "https://files.pythonhosted.org/packages/47/66/1ef1081c831c5b611f6f55f6302166cfa1bc9574017410ba5595353f846a/regex-2025.9.1-cp311-cp311-win32.whl", hash = "sha256:49865e78d147a7a4f143064488da5d549be6bfc3f2579e5044cac61f5c92edd4", size = 264118, upload-time = "2025-09-01T22:08:30.388Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e0/8adc550d7169df1d6b9be8ff6019cda5291054a0107760c2f30788b6195f/regex-2025.9.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:d34b901f6f2f02ef60f4ad3855d3a02378c65b094efc4b80388a3aeb700a5de7", size = 276151, upload-time = "2025-09-01T22:08:32.073Z" }, + { url = "https://files.pythonhosted.org/packages/cb/bd/46fef29341396d955066e55384fb93b0be7d64693842bf4a9a398db6e555/regex-2025.9.1-cp311-cp311-win_arm64.whl", hash = "sha256:47d7c2dab7e0b95b95fd580087b6ae196039d62306a592fa4e162e49004b6299", size = 268460, upload-time = "2025-09-01T22:08:33.281Z" }, + { url = "https://files.pythonhosted.org/packages/39/ef/a0372febc5a1d44c1be75f35d7e5aff40c659ecde864d7fa10e138f75e74/regex-2025.9.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:84a25164bd8dcfa9f11c53f561ae9766e506e580b70279d05a7946510bdd6f6a", size = 486317, upload-time = "2025-09-01T22:08:34.529Z" }, + { url = "https://files.pythonhosted.org/packages/b5/25/d64543fb7eb41a1024786d518cc57faf1ce64aa6e9ddba097675a0c2f1d2/regex-2025.9.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:645e88a73861c64c1af558dd12294fb4e67b5c1eae0096a60d7d8a2143a611c7", size = 289698, upload-time = "2025-09-01T22:08:36.162Z" }, + { url = "https://files.pythonhosted.org/packages/d8/dc/fbf31fc60be317bd9f6f87daa40a8a9669b3b392aa8fe4313df0a39d0722/regex-2025.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:10a450cba5cd5409526ee1d4449f42aad38dd83ac6948cbd6d7f71ca7018f7db", size = 287242, upload-time = "2025-09-01T22:08:37.794Z" }, + { url = "https://files.pythonhosted.org/packages/0f/74/f933a607a538f785da5021acf5323961b4620972e2c2f1f39b6af4b71db7/regex-2025.9.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9dc5991592933a4192c166eeb67b29d9234f9c86344481173d1bc52f73a7104", size = 797441, upload-time = "2025-09-01T22:08:39.108Z" }, + { url = "https://files.pythonhosted.org/packages/89/d0/71fc49b4f20e31e97f199348b8c4d6e613e7b6a54a90eb1b090c2b8496d7/regex-2025.9.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a32291add816961aab472f4fad344c92871a2ee33c6c219b6598e98c1f0108f2", size = 862654, upload-time = "2025-09-01T22:08:40.586Z" }, + { url = "https://files.pythonhosted.org/packages/59/05/984edce1411a5685ba9abbe10d42cdd9450aab4a022271f9585539788150/regex-2025.9.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:588c161a68a383478e27442a678e3b197b13c5ba51dbba40c1ccb8c4c7bee9e9", size = 910862, upload-time = "2025-09-01T22:08:42.416Z" }, + { url = "https://files.pythonhosted.org/packages/b2/02/5c891bb5fe0691cc1bad336e3a94b9097fbcf9707ec8ddc1dce9f0397289/regex-2025.9.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47829ffaf652f30d579534da9085fe30c171fa2a6744a93d52ef7195dc38218b", size = 801991, upload-time = "2025-09-01T22:08:44.072Z" }, + { url = "https://files.pythonhosted.org/packages/f1/ae/fd10d6ad179910f7a1b3e0a7fde1ef8bb65e738e8ac4fd6ecff3f52252e4/regex-2025.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e978e5a35b293ea43f140c92a3269b6ab13fe0a2bf8a881f7ac740f5a6ade85", size = 786651, upload-time = "2025-09-01T22:08:46.079Z" }, + { url = "https://files.pythonhosted.org/packages/30/cf/9d686b07bbc5bf94c879cc168db92542d6bc9fb67088d03479fef09ba9d3/regex-2025.9.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4cf09903e72411f4bf3ac1eddd624ecfd423f14b2e4bf1c8b547b72f248b7bf7", size = 856556, upload-time = "2025-09-01T22:08:48.376Z" }, + { url = 
"https://files.pythonhosted.org/packages/91/9d/302f8a29bb8a49528abbab2d357a793e2a59b645c54deae0050f8474785b/regex-2025.9.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d016b0f77be63e49613c9e26aaf4a242f196cd3d7a4f15898f5f0ab55c9b24d2", size = 849001, upload-time = "2025-09-01T22:08:50.067Z" }, + { url = "https://files.pythonhosted.org/packages/93/fa/b4c6dbdedc85ef4caec54c817cd5f4418dbfa2453214119f2538082bf666/regex-2025.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:656563e620de6908cd1c9d4f7b9e0777e3341ca7db9d4383bcaa44709c90281e", size = 788138, upload-time = "2025-09-01T22:08:51.933Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1b/91ee17a3cbf87f81e8c110399279d0e57f33405468f6e70809100f2ff7d8/regex-2025.9.1-cp312-cp312-win32.whl", hash = "sha256:df33f4ef07b68f7ab637b1dbd70accbf42ef0021c201660656601e8a9835de45", size = 264524, upload-time = "2025-09-01T22:08:53.75Z" }, + { url = "https://files.pythonhosted.org/packages/92/28/6ba31cce05b0f1ec6b787921903f83bd0acf8efde55219435572af83c350/regex-2025.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:5aba22dfbc60cda7c0853516104724dc904caa2db55f2c3e6e984eb858d3edf3", size = 275489, upload-time = "2025-09-01T22:08:55.037Z" }, + { url = "https://files.pythonhosted.org/packages/bd/ed/ea49f324db00196e9ef7fe00dd13c6164d5173dd0f1bbe495e61bb1fb09d/regex-2025.9.1-cp312-cp312-win_arm64.whl", hash = "sha256:ec1efb4c25e1849c2685fa95da44bfde1b28c62d356f9c8d861d4dad89ed56e9", size = 268589, upload-time = "2025-09-01T22:08:56.369Z" }, + { url = "https://files.pythonhosted.org/packages/98/25/b2959ce90c6138c5142fe5264ee1f9b71a0c502ca4c7959302a749407c79/regex-2025.9.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:bc6834727d1b98d710a63e6c823edf6ffbf5792eba35d3fa119531349d4142ef", size = 485932, upload-time = "2025-09-01T22:08:57.913Z" }, + { url = "https://files.pythonhosted.org/packages/49/2e/6507a2a85f3f2be6643438b7bd976e67ad73223692d6988eb1ff444106d3/regex-2025.9.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c3dc05b6d579875719bccc5f3037b4dc80433d64e94681a0061845bd8863c025", size = 289568, upload-time = "2025-09-01T22:08:59.258Z" }, + { url = "https://files.pythonhosted.org/packages/c7/d8/de4a4b57215d99868f1640e062a7907e185ec7476b4b689e2345487c1ff4/regex-2025.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:22213527df4c985ec4a729b055a8306272d41d2f45908d7bacb79be0fa7a75ad", size = 286984, upload-time = "2025-09-01T22:09:00.835Z" }, + { url = "https://files.pythonhosted.org/packages/03/15/e8cb403403a57ed316e80661db0e54d7aa2efcd85cb6156f33cc18746922/regex-2025.9.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8e3f6e3c5a5a1adc3f7ea1b5aec89abfc2f4fbfba55dafb4343cd1d084f715b2", size = 797514, upload-time = "2025-09-01T22:09:02.538Z" }, + { url = "https://files.pythonhosted.org/packages/e4/26/2446f2b9585fed61faaa7e2bbce3aca7dd8df6554c32addee4c4caecf24a/regex-2025.9.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:bcb89c02a0d6c2bec9b0bb2d8c78782699afe8434493bfa6b4021cc51503f249", size = 862586, upload-time = "2025-09-01T22:09:04.322Z" }, + { url = "https://files.pythonhosted.org/packages/fd/b8/82ffbe9c0992c31bbe6ae1c4b4e21269a5df2559102b90543c9b56724c3c/regex-2025.9.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b0e2f95413eb0c651cd1516a670036315b91b71767af83bc8525350d4375ccba", size = 910815, upload-time = "2025-09-01T22:09:05.978Z" }, + { url = 
"https://files.pythonhosted.org/packages/2f/d8/7303ea38911759c1ee30cc5bc623ee85d3196b733c51fd6703c34290a8d9/regex-2025.9.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09a41dc039e1c97d3c2ed3e26523f748e58c4de3ea7a31f95e1cf9ff973fff5a", size = 802042, upload-time = "2025-09-01T22:09:07.865Z" }, + { url = "https://files.pythonhosted.org/packages/fc/0e/6ad51a55ed4b5af512bb3299a05d33309bda1c1d1e1808fa869a0bed31bc/regex-2025.9.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f0b4258b161094f66857a26ee938d3fe7b8a5063861e44571215c44fbf0e5df", size = 786764, upload-time = "2025-09-01T22:09:09.362Z" }, + { url = "https://files.pythonhosted.org/packages/8d/d5/394e3ffae6baa5a9217bbd14d96e0e5da47bb069d0dbb8278e2681a2b938/regex-2025.9.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bf70e18ac390e6977ea7e56f921768002cb0fa359c4199606c7219854ae332e0", size = 856557, upload-time = "2025-09-01T22:09:11.129Z" }, + { url = "https://files.pythonhosted.org/packages/cd/80/b288d3910c41194ad081b9fb4b371b76b0bbfdce93e7709fc98df27b37dc/regex-2025.9.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b84036511e1d2bb0a4ff1aec26951caa2dea8772b223c9e8a19ed8885b32dbac", size = 849108, upload-time = "2025-09-01T22:09:12.877Z" }, + { url = "https://files.pythonhosted.org/packages/d1/cd/5ec76bf626d0d5abdc277b7a1734696f5f3d14fbb4a3e2540665bc305d85/regex-2025.9.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c2e05dcdfe224047f2a59e70408274c325d019aad96227ab959403ba7d58d2d7", size = 788201, upload-time = "2025-09-01T22:09:14.561Z" }, + { url = "https://files.pythonhosted.org/packages/b5/36/674672f3fdead107565a2499f3007788b878188acec6d42bc141c5366c2c/regex-2025.9.1-cp313-cp313-win32.whl", hash = "sha256:3b9a62107a7441b81ca98261808fed30ae36ba06c8b7ee435308806bd53c1ed8", size = 264508, upload-time = "2025-09-01T22:09:16.193Z" }, + { url = "https://files.pythonhosted.org/packages/83/ad/931134539515eb64ce36c24457a98b83c1b2e2d45adf3254b94df3735a76/regex-2025.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:b38afecc10c177eb34cfae68d669d5161880849ba70c05cbfbe409f08cc939d7", size = 275469, upload-time = "2025-09-01T22:09:17.462Z" }, + { url = "https://files.pythonhosted.org/packages/24/8c/96d34e61c0e4e9248836bf86d69cb224fd222f270fa9045b24e218b65604/regex-2025.9.1-cp313-cp313-win_arm64.whl", hash = "sha256:ec329890ad5e7ed9fc292858554d28d58d56bf62cf964faf0aa57964b21155a0", size = 268586, upload-time = "2025-09-01T22:09:18.948Z" }, + { url = "https://files.pythonhosted.org/packages/21/b1/453cbea5323b049181ec6344a803777914074b9726c9c5dc76749966d12d/regex-2025.9.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:72fb7a016467d364546f22b5ae86c45680a4e0de6b2a6f67441d22172ff641f1", size = 486111, upload-time = "2025-09-01T22:09:20.734Z" }, + { url = "https://files.pythonhosted.org/packages/f6/0e/92577f197bd2f7652c5e2857f399936c1876978474ecc5b068c6d8a79c86/regex-2025.9.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c9527fa74eba53f98ad86be2ba003b3ebe97e94b6eb2b916b31b5f055622ef03", size = 289520, upload-time = "2025-09-01T22:09:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/af/c6/b472398116cca7ea5a6c4d5ccd0fc543f7fd2492cb0c48d2852a11972f73/regex-2025.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c905d925d194c83a63f92422af7544ec188301451b292c8b487f0543726107ca", size = 287215, upload-time = "2025-09-01T22:09:23.657Z" }, + { url = 
"https://files.pythonhosted.org/packages/cf/11/f12ecb0cf9ca792a32bb92f758589a84149017467a544f2f6bfb45c0356d/regex-2025.9.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74df7c74a63adcad314426b1f4ea6054a5ab25d05b0244f0c07ff9ce640fa597", size = 797855, upload-time = "2025-09-01T22:09:25.197Z" }, + { url = "https://files.pythonhosted.org/packages/46/88/bbb848f719a540fb5997e71310f16f0b33a92c5d4b4d72d4311487fff2a3/regex-2025.9.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4f6e935e98ea48c7a2e8be44494de337b57a204470e7f9c9c42f912c414cd6f5", size = 863363, upload-time = "2025-09-01T22:09:26.705Z" }, + { url = "https://files.pythonhosted.org/packages/54/a9/2321eb3e2838f575a78d48e03c1e83ea61bd08b74b7ebbdeca8abc50fc25/regex-2025.9.1-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4a62d033cd9ebefc7c5e466731a508dfabee827d80b13f455de68a50d3c2543d", size = 910202, upload-time = "2025-09-01T22:09:28.906Z" }, + { url = "https://files.pythonhosted.org/packages/33/07/d1d70835d7d11b7e126181f316f7213c4572ecf5c5c97bdbb969fb1f38a2/regex-2025.9.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ef971ebf2b93bdc88d8337238be4dfb851cc97ed6808eb04870ef67589415171", size = 801808, upload-time = "2025-09-01T22:09:30.733Z" }, + { url = "https://files.pythonhosted.org/packages/13/d1/29e4d1bed514ef2bf3a4ead3cb8bb88ca8af94130239a4e68aa765c35b1c/regex-2025.9.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d936a1db208bdca0eca1f2bb2c1ba1d8370b226785c1e6db76e32a228ffd0ad5", size = 786824, upload-time = "2025-09-01T22:09:32.61Z" }, + { url = "https://files.pythonhosted.org/packages/33/27/20d8ccb1bee460faaa851e6e7cc4cfe852a42b70caa1dca22721ba19f02f/regex-2025.9.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:7e786d9e4469698fc63815b8de08a89165a0aa851720eb99f5e0ea9d51dd2b6a", size = 857406, upload-time = "2025-09-01T22:09:34.117Z" }, + { url = "https://files.pythonhosted.org/packages/74/fe/60c6132262dc36430d51e0c46c49927d113d3a38c1aba6a26c7744c84cf3/regex-2025.9.1-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:6b81d7dbc5466ad2c57ce3a0ddb717858fe1a29535c8866f8514d785fdb9fc5b", size = 848593, upload-time = "2025-09-01T22:09:35.598Z" }, + { url = "https://files.pythonhosted.org/packages/cc/ae/2d4ff915622fabbef1af28387bf71e7f2f4944a348b8460d061e85e29bf0/regex-2025.9.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:cd4890e184a6feb0ef195338a6ce68906a8903a0f2eb7e0ab727dbc0a3156273", size = 787951, upload-time = "2025-09-01T22:09:37.139Z" }, + { url = "https://files.pythonhosted.org/packages/85/37/dc127703a9e715a284cc2f7dbdd8a9776fd813c85c126eddbcbdd1ca5fec/regex-2025.9.1-cp314-cp314-win32.whl", hash = "sha256:34679a86230e46164c9e0396b56cab13c0505972343880b9e705083cc5b8ec86", size = 269833, upload-time = "2025-09-01T22:09:39.245Z" }, + { url = "https://files.pythonhosted.org/packages/83/bf/4bed4d3d0570e16771defd5f8f15f7ea2311edcbe91077436d6908956c4a/regex-2025.9.1-cp314-cp314-win_amd64.whl", hash = "sha256:a1196e530a6bfa5f4bde029ac5b0295a6ecfaaffbfffede4bbaf4061d9455b70", size = 278742, upload-time = "2025-09-01T22:09:40.651Z" }, + { url = "https://files.pythonhosted.org/packages/cf/3e/7d7ac6fd085023312421e0d69dfabdfb28e116e513fadbe9afe710c01893/regex-2025.9.1-cp314-cp314-win_arm64.whl", hash = "sha256:f46d525934871ea772930e997d577d48c6983e50f206ff7b66d4ac5f8941e993", size = 271860, upload-time = 
"2025-09-01T22:09:42.413Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "sentry-sdk" +version = "2.35.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/79/0ecb942f3f1ad26c40c27f81ff82392d85c01d26a45e3c72c2b37807e680/sentry_sdk-2.35.2.tar.gz", hash = "sha256:e9e8f3c795044beb59f2c8f4c6b9b0f9779e5e604099882df05eec525e782cc6", size = 343377, upload-time = "2025-09-01T11:00:58.633Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/91/a43308dc82a0e32d80cd0dfdcfca401ecbd0f431ab45f24e48bb97b7800d/sentry_sdk-2.35.2-py2.py3-none-any.whl", hash = "sha256:38c98e3cbb620dd3dd80a8d6e39c753d453dd41f8a9df581b0584c19a52bc926", size = 363975, upload-time = "2025-09-01T11:00:56.574Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "starlette" +version = "0.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tiktoken" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/86/ad0155a37c4f310935d5ac0b1ccf9bdb635dcb906e0a9a26b616dd55825a/tiktoken-0.11.0.tar.gz", hash = "sha256:3c518641aee1c52247c2b97e74d8d07d780092af79d5911a6ab5e79359d9b06a", size = 37648, upload-time = "2025-08-08T23:58:08.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/4d/c6a2e7dca2b4f2e9e0bfd62b3fe4f114322e2c028cfba905a72bc76ce479/tiktoken-0.11.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:8a9b517d6331d7103f8bef29ef93b3cca95fa766e293147fe7bacddf310d5917", size = 1059937, upload-time = "2025-08-08T23:57:28.57Z" }, + { url = "https://files.pythonhosted.org/packages/41/54/3739d35b9f94cb8dc7b0db2edca7192d5571606aa2369a664fa27e811804/tiktoken-0.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:b4ddb1849e6bf0afa6cc1c5d809fb980ca240a5fffe585a04e119519758788c0", size = 999230, upload-time = "2025-08-08T23:57:30.241Z" }, + { url = "https://files.pythonhosted.org/packages/dd/f4/ec8d43338d28d53513004ebf4cd83732a135d11011433c58bf045890cc10/tiktoken-0.11.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10331d08b5ecf7a780b4fe4d0281328b23ab22cdb4ff65e68d56caeda9940ecc", size = 1130076, upload-time = "2025-08-08T23:57:31.706Z" }, + { url = "https://files.pythonhosted.org/packages/94/80/fb0ada0a882cb453caf519a4bf0d117c2a3ee2e852c88775abff5413c176/tiktoken-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b062c82300341dc87e0258c69f79bed725f87e753c21887aea90d272816be882", size = 1183942, upload-time = "2025-08-08T23:57:33.142Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e9/6c104355b463601719582823f3ea658bc3aa7c73d1b3b7553ebdc48468ce/tiktoken-0.11.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:195d84bec46169af3b1349a1495c151d37a0ff4cba73fd08282736be7f92cc6c", size = 1244705, upload-time = "2025-08-08T23:57:34.594Z" }, + { url = "https://files.pythonhosted.org/packages/94/75/eaa6068f47e8b3f0aab9e05177cce2cf5aa2cc0ca93981792e620d4d4117/tiktoken-0.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:fe91581b0ecdd8783ce8cb6e3178f2260a3912e8724d2f2d49552b98714641a1", size = 884152, upload-time = "2025-08-08T23:57:36.18Z" }, + { url = "https://files.pythonhosted.org/packages/8a/91/912b459799a025d2842566fe1e902f7f50d54a1ce8a0f236ab36b5bd5846/tiktoken-0.11.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4ae374c46afadad0f501046db3da1b36cd4dfbfa52af23c998773682446097cf", size = 1059743, upload-time = "2025-08-08T23:57:37.516Z" }, + { url = "https://files.pythonhosted.org/packages/8c/e9/6faa6870489ce64f5f75dcf91512bf35af5864583aee8fcb0dcb593121f5/tiktoken-0.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:25a512ff25dc6c85b58f5dd4f3d8c674dc05f96b02d66cdacf628d26a4e4866b", size = 999334, upload-time = "2025-08-08T23:57:38.595Z" }, + { url = "https://files.pythonhosted.org/packages/a1/3e/a05d1547cf7db9dc75d1461cfa7b556a3b48e0516ec29dfc81d984a145f6/tiktoken-0.11.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2130127471e293d385179c1f3f9cd445070c0772be73cdafb7cec9a3684c0458", size = 1129402, upload-time = "2025-08-08T23:57:39.627Z" }, + { url = "https://files.pythonhosted.org/packages/34/9a/db7a86b829e05a01fd4daa492086f708e0a8b53952e1dbc9d380d2b03677/tiktoken-0.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e43022bf2c33f733ea9b54f6a3f6b4354b909f5a73388fb1b9347ca54a069c", size = 1184046, upload-time = "2025-08-08T23:57:40.689Z" }, + { url = "https://files.pythonhosted.org/packages/9d/bb/52edc8e078cf062ed749248f1454e9e5cfd09979baadb830b3940e522015/tiktoken-0.11.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:adb4e308eb64380dc70fa30493e21c93475eaa11669dea313b6bbf8210bfd013", size = 1244691, upload-time = "2025-08-08T23:57:42.251Z" }, + { url = "https://files.pythonhosted.org/packages/60/d9/884b6cd7ae2570ecdcaffa02b528522b18fef1cbbfdbcaa73799807d0d3b/tiktoken-0.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:ece6b76bfeeb61a125c44bbefdfccc279b5288e6007fbedc0d32bfec602df2f2", size = 884392, upload-time = "2025-08-08T23:57:43.628Z" }, + { url = "https://files.pythonhosted.org/packages/e7/9e/eceddeffc169fc75fe0fd4f38471309f11cb1906f9b8aa39be4f5817df65/tiktoken-0.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:fd9e6b23e860973cf9526544e220b223c60badf5b62e80a33509d6d40e6c8f5d", size = 1055199, upload-time = "2025-08-08T23:57:45.076Z" }, + { url = "https://files.pythonhosted.org/packages/4f/cf/5f02bfefffdc6b54e5094d2897bc80efd43050e5b09b576fd85936ee54bf/tiktoken-0.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a76d53cee2da71ee2731c9caa747398762bda19d7f92665e882fef229cb0b5b", size = 996655, upload-time = "2025-08-08T23:57:46.304Z" }, + { url = "https://files.pythonhosted.org/packages/65/8e/c769b45ef379bc360c9978c4f6914c79fd432400a6733a8afc7ed7b0726a/tiktoken-0.11.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ef72aab3ea240646e642413cb363b73869fed4e604dcfd69eec63dc54d603e8", size = 1128867, upload-time = "2025-08-08T23:57:47.438Z" }, + { url = "https://files.pythonhosted.org/packages/d5/2d/4d77f6feb9292bfdd23d5813e442b3bba883f42d0ac78ef5fdc56873f756/tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f929255c705efec7a28bf515e29dc74220b2f07544a8c81b8d69e8efc4578bd", size = 1183308, upload-time = "2025-08-08T23:57:48.566Z" }, + { url = "https://files.pythonhosted.org/packages/7a/65/7ff0a65d3bb0fc5a1fb6cc71b03e0f6e71a68c5eea230d1ff1ba3fd6df49/tiktoken-0.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:61f1d15822e4404953d499fd1dcc62817a12ae9fb1e4898033ec8fe3915fdf8e", size = 1244301, upload-time = "2025-08-08T23:57:49.642Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6e/5b71578799b72e5bdcef206a214c3ce860d999d579a3b56e74a6c8989ee2/tiktoken-0.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:45927a71ab6643dfd3ef57d515a5db3d199137adf551f66453be098502838b0f", size = 884282, upload-time = "2025-08-08T23:57:50.759Z" }, + { url = "https://files.pythonhosted.org/packages/cc/cd/a9034bcee638716d9310443818d73c6387a6a96db93cbcb0819b77f5b206/tiktoken-0.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a5f3f25ffb152ee7fec78e90a5e5ea5b03b4ea240beed03305615847f7a6ace2", size = 1055339, upload-time = "2025-08-08T23:57:51.802Z" }, + { url = "https://files.pythonhosted.org/packages/f1/91/9922b345f611b4e92581f234e64e9661e1c524875c8eadd513c4b2088472/tiktoken-0.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7dc6e9ad16a2a75b4c4be7208055a1f707c9510541d94d9cc31f7fbdc8db41d8", size = 997080, upload-time = "2025-08-08T23:57:53.442Z" }, + { url = "https://files.pythonhosted.org/packages/d0/9d/49cd047c71336bc4b4af460ac213ec1c457da67712bde59b892e84f1859f/tiktoken-0.11.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a0517634d67a8a48fd4a4ad73930c3022629a85a217d256a6e9b8b47439d1e4", size = 1128501, upload-time = "2025-08-08T23:57:54.808Z" }, + { url = "https://files.pythonhosted.org/packages/52/d5/a0dcdb40dd2ea357e83cb36258967f0ae96f5dd40c722d6e382ceee6bba9/tiktoken-0.11.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fb4effe60574675118b73c6fbfd3b5868e5d7a1f570d6cc0d18724b09ecf318", size = 1182743, upload-time = "2025-08-08T23:57:56.307Z" }, + { url = "https://files.pythonhosted.org/packages/3b/17/a0fc51aefb66b7b5261ca1314afa83df0106b033f783f9a7bcbe8e741494/tiktoken-0.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:94f984c9831fd32688aef4348803b0905d4ae9c432303087bae370dc1381a2b8", size = 1244057, upload-time = "2025-08-08T23:57:57.628Z" }, + { url = "https://files.pythonhosted.org/packages/50/79/bcf350609f3a10f09fe4fc207f132085e497fdd3612f3925ab24d86a0ca0/tiktoken-0.11.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:2177ffda31dec4023356a441793fed82f7af5291120751dee4d696414f54db0c", size = 883901, upload-time = "2025-08-08T23:57:59.359Z" }, +] + +[[package]] +name = "tokenizers" +version = "0.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/b4/c1ce3699e81977da2ace8b16d2badfd42b060e7d33d75c4ccdbf9dc920fa/tokenizers-0.22.0.tar.gz", hash = "sha256:2e33b98525be8453f355927f3cab312c36cd3e44f4d7e9e97da2fa94d0a49dcb", size = 362771, upload-time = "2025-08-29T10:25:33.914Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/b1/18c13648edabbe66baa85fe266a478a7931ddc0cd1ba618802eb7b8d9865/tokenizers-0.22.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:eaa9620122a3fb99b943f864af95ed14c8dfc0f47afa3b404ac8c16b3f2bb484", size = 3081954, upload-time = "2025-08-29T10:25:24.993Z" }, + { url = "https://files.pythonhosted.org/packages/c2/02/c3c454b641bd7c4f79e4464accfae9e7dfc913a777d2e561e168ae060362/tokenizers-0.22.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:71784b9ab5bf0ff3075bceeb198149d2c5e068549c0d18fe32d06ba0deb63f79", size = 2945644, upload-time = "2025-08-29T10:25:23.405Z" }, + { url = "https://files.pythonhosted.org/packages/55/02/d10185ba2fd8c2d111e124c9d92de398aee0264b35ce433f79fb8472f5d0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec5b71f668a8076802b0241a42387d48289f25435b86b769ae1837cad4172a17", size = 3254764, upload-time = "2025-08-29T10:25:12.445Z" }, + { url = "https://files.pythonhosted.org/packages/13/89/17514bd7ef4bf5bfff58e2b131cec0f8d5cea2b1c8ffe1050a2c8de88dbb/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ea8562fa7498850d02a16178105b58803ea825b50dc9094d60549a7ed63654bb", size = 3161654, upload-time = "2025-08-29T10:25:15.493Z" }, + { url = "https://files.pythonhosted.org/packages/5a/d8/bac9f3a7ef6dcceec206e3857c3b61bb16c6b702ed7ae49585f5bd85c0ef/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4136e1558a9ef2e2f1de1555dcd573e1cbc4a320c1a06c4107a3d46dc8ac6e4b", size = 3511484, upload-time = "2025-08-29T10:25:20.477Z" }, + { url = "https://files.pythonhosted.org/packages/aa/27/9c9800eb6763683010a4851db4d1802d8cab9cec114c17056eccb4d4a6e0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf5954de3962a5fd9781dc12048d24a1a6f1f5df038c6e95db328cd22964206", size = 3712829, upload-time = "2025-08-29T10:25:17.154Z" }, + { url = "https://files.pythonhosted.org/packages/10/e3/b1726dbc1f03f757260fa21752e1921445b5bc350389a8314dd3338836db/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8337ca75d0731fc4860e6204cc24bb36a67d9736142aa06ed320943b50b1e7ed", size = 3408934, upload-time = "2025-08-29T10:25:18.76Z" }, + { url = "https://files.pythonhosted.org/packages/d4/61/aeab3402c26874b74bb67a7f2c4b569dde29b51032c5384db592e7b216f4/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a89264e26f63c449d8cded9061adea7b5de53ba2346fc7e87311f7e4117c1cc8", size = 3345585, upload-time = "2025-08-29T10:25:22.08Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d3/498b4a8a8764cce0900af1add0f176ff24f475d4413d55b760b8cdf00893/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:790bad50a1b59d4c21592f9c3cf5e5cf9c3c7ce7e1a23a739f13e01fb1be377a", size = 9322986, upload-time = 
"2025-08-29T10:25:26.607Z" }, + { url = "https://files.pythonhosted.org/packages/a2/62/92378eb1c2c565837ca3cb5f9569860d132ab9d195d7950c1ea2681dffd0/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:76cf6757c73a10ef10bf06fa937c0ec7393d90432f543f49adc8cab3fb6f26cb", size = 9276630, upload-time = "2025-08-29T10:25:28.349Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f0/342d80457aa1cda7654327460f69db0d69405af1e4c453f4dc6ca7c4a76e/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:1626cb186e143720c62c6c6b5371e62bbc10af60481388c0da89bc903f37ea0c", size = 9547175, upload-time = "2025-08-29T10:25:29.989Z" }, + { url = "https://files.pythonhosted.org/packages/14/84/8aa9b4adfc4fbd09381e20a5bc6aa27040c9c09caa89988c01544e008d18/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:da589a61cbfea18ae267723d6b029b84598dc8ca78db9951d8f5beff72d8507c", size = 9692735, upload-time = "2025-08-29T10:25:32.089Z" }, + { url = "https://files.pythonhosted.org/packages/bf/24/83ee2b1dc76bfe05c3142e7d0ccdfe69f0ad2f1ebf6c726cea7f0874c0d0/tokenizers-0.22.0-cp39-abi3-win32.whl", hash = "sha256:dbf9d6851bddae3e046fedfb166f47743c1c7bd11c640f0691dd35ef0bcad3be", size = 2471915, upload-time = "2025-08-29T10:25:36.411Z" }, + { url = "https://files.pythonhosted.org/packages/d1/9b/0e0bf82214ee20231845b127aa4a8015936ad5a46779f30865d10e404167/tokenizers-0.22.0-cp39-abi3-win_amd64.whl", hash = "sha256:c78174859eeaee96021f248a56c801e36bfb6bd5b067f2e95aa82445ca324f00", size = 2680494, upload-time = "2025-08-29T10:25:35.14Z" }, +] + +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077, upload-time = "2024-11-27T22:37:54.956Z" }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429, upload-time = "2024-11-27T22:37:56.698Z" }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067, upload-time = "2024-11-27T22:37:57.63Z" }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030, upload-time = "2024-11-27T22:37:59.344Z" }, + { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898, upload-time = "2024-11-27T22:38:00.429Z" }, + { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894, upload-time = "2024-11-27T22:38:02.094Z" }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319, upload-time = "2024-11-27T22:38:03.206Z" }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273, upload-time = "2024-11-27T22:38:04.217Z" }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310, upload-time = "2024-11-27T22:38:05.908Z" }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309, upload-time = "2024-11-27T22:38:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762, upload-time = "2024-11-27T22:38:07.731Z" }, + { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453, upload-time = "2024-11-27T22:38:09.384Z" }, + { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486, upload-time = "2024-11-27T22:38:10.329Z" }, + { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349, upload-time = "2024-11-27T22:38:11.443Z" }, + { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159, upload-time = "2024-11-27T22:38:13.099Z" }, + { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", 
size = 237243, upload-time = "2024-11-27T22:38:14.766Z" }, + { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645, upload-time = "2024-11-27T22:38:15.843Z" }, + { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584, upload-time = "2024-11-27T22:38:17.645Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875, upload-time = "2024-11-27T22:38:19.159Z" }, + { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418, upload-time = "2024-11-27T22:38:20.064Z" }, + { url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708, upload-time = "2024-11-27T22:38:21.659Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582, upload-time = "2024-11-27T22:38:22.693Z" }, + { url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543, upload-time = "2024-11-27T22:38:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691, upload-time = "2024-11-27T22:38:26.081Z" }, + { url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170, upload-time = "2024-11-27T22:38:27.921Z" }, + { url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530, upload-time = "2024-11-27T22:38:29.591Z" }, + { url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666, upload-time = "2024-11-27T22:38:30.639Z" }, + { url = 
"https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954, upload-time = "2024-11-27T22:38:31.702Z" }, + { url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724, upload-time = "2024-11-27T22:38:32.837Z" }, + { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383, upload-time = "2024-11-27T22:38:34.455Z" }, + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, +] + +[[package]] +name = "torch" +version = "2.8.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "python_full_version >= '3.12'" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:43938e9a174c90e5eb9e906532b2f1e21532bbfa5a61b65193b4f54714d34f9e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:039b9dcdd6bdbaa10a8a5cd6be22c4cb3e3589a341e5f904cbb571ca28f55bed" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:34c55443aafd31046a7963b63d30bc3b628ee4a704f826796c865fdfd05bb596" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:0ad925202387f4e7314302a1b4f8860fa824357f9b1466d7992bf276370ebcff" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3a852369a38dec343d45ecd0bc3660f79b88a23e0c878d18707f7c13bf49538f" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:9e20646802b7fc295c1f8b45fefcfc9fb2e4ec9cbe8593443cd2b9cc307c8405" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4295a22d69408e93d25f51e8d5d579345b6b802383e9414b0f3853ed433d53ae" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + +[[package]] +name = "triton" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools", marker = "sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, + { url = 
"https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" }, + { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, +] + +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size 
= 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/5e/f0cd46063a02fd8515f0e880c37d2657845b7306c16ce6c4ffc44afd9036/uvicorn-0.36.0.tar.gz", hash = "sha256:527dc68d77819919d90a6b267be55f0e76704dca829d34aea9480be831a9b9d9", size = 80032, upload-time = "2025-09-20T01:07:14.418Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/06/5cc0542b47c0338c1cb676b348e24a1c29acabc81000bced518231dded6f/uvicorn-0.36.0-py3-none-any.whl", hash = "sha256:6bb4ba67f16024883af8adf13aba3a9919e415358604ce46780d3f9bdc36d731", size = 67675, upload-time = "2025-09-20T01:07:12.984Z" }, +] + +[[package]] +name = "wandb" +version = "0.21.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/84/af6ccdf95e56f15aceb360e437fbfcca3dc91ad8ca335fe482083e29f7a5/wandb-0.21.3.tar.gz", hash = "sha256:031e24e2aad0ce735dfdcc74baf2f2c12c106f500ed24798de6ef9b9e63bb432", size = 40146972, upload-time = "2025-08-30T18:21:55.138Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/e8/b5bfbbc7f76c11fd0665b92be8a38c6a83b27f353552233b9959b21be488/wandb-0.21.3-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:f85bac45b4482742ec9ff190af38eb00a877ddeb4875475e7e487dc19300ff03", size = 18820209, upload-time = "2025-08-30T18:21:33.47Z" }, + { url = "https://files.pythonhosted.org/packages/59/a3/03f0fcde49609df1cb3a382fb5053f601b88da448bcd415ed7f75272eee7/wandb-0.21.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:8a2b3ba419b91d47edead2755f04cef54f9e3c4496ee0c9854c3cfeff4216dd3", size = 18310636, upload-time = "2025-08-30T18:21:37.405Z" }, + { url = "https://files.pythonhosted.org/packages/1d/c3/d6048db30ff2e3c67089ba0e94878572fd26137b146f8e3b27bbdf428b31/wandb-0.21.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:35a1972881f3b85755befab004118234593792a9f05e07fd6345780172f4420e", size = 19053277, upload-time = "2025-08-30T18:21:39.389Z" }, + { url = "https://files.pythonhosted.org/packages/ea/7f/805c3d2fa9e3b8b6bf2bc534887c9ed97bdf22007ca8ba59424a1c8bb360/wandb-0.21.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d9cf8588cb090a2a41f589037fda72c57c9e23edfbd2ad829e575f1305d942c", size = 18130850, upload-time = "2025-08-30T18:21:41.573Z" }, + { url = "https://files.pythonhosted.org/packages/5b/af/a3252e5afac98a036f83c65ec92cadf6677ccdaacbbb2151da29f694d136/wandb-0.21.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff24b6b8e0f9da840b6bd5c7f60b0a5507bd998db40c9c2d476f9a340bec8ed", size = 19570305, upload-time = "2025-08-30T18:21:43.811Z" }, + { url = 
"https://files.pythonhosted.org/packages/4d/f9/4404b5a24bfd4ba027c19d30152b0fc7ebca8c49b202dee6ecb7f316082c/wandb-0.21.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4975dec19e2b343e23ed6e60f7e1290120553719f82e87a22205bede758416ad", size = 18135806, upload-time = "2025-08-30T18:21:46.211Z" }, + { url = "https://files.pythonhosted.org/packages/ff/32/9580f42899e54f3d0b4ea619b6f6a54980a4e36fd0675d58c09f0a08d3f6/wandb-0.21.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:514a0aad40ecc0bdb757b1dc86e4ac98f61d2d760445b6e1f555291562320f2d", size = 19646760, upload-time = "2025-08-30T18:21:48.768Z" }, + { url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" }, + { url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, upload-time = "2025-08-30T18:21:52.874Z" }, +] + +[[package]] +name = "xxhash" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/00/5e/d6e5258d69df8b4ed8c83b6664f2b47d30d2dec551a29ad72a6c69eafd31/xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f", size = 84241, upload-time = "2024-08-17T09:20:38.972Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/8a/0e9feca390d512d293afd844d31670e25608c4a901e10202aa98785eab09/xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212", size = 31970, upload-time = "2024-08-17T09:17:35.675Z" }, + { url = "https://files.pythonhosted.org/packages/16/e6/be5aa49580cd064a18200ab78e29b88b1127e1a8c7955eb8ecf81f2626eb/xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520", size = 30801, upload-time = "2024-08-17T09:17:37.353Z" }, + { url = "https://files.pythonhosted.org/packages/20/ee/b8a99ebbc6d1113b3a3f09e747fa318c3cde5b04bd9c197688fadf0eeae8/xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680", size = 220927, upload-time = "2024-08-17T09:17:38.835Z" }, + { url = "https://files.pythonhosted.org/packages/58/62/15d10582ef159283a5c2b47f6d799fc3303fe3911d5bb0bcc820e1ef7ff4/xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da", size = 200360, upload-time = "2024-08-17T09:17:40.851Z" }, + { url = "https://files.pythonhosted.org/packages/23/41/61202663ea9b1bd8e53673b8ec9e2619989353dba8cfb68e59a9cbd9ffe3/xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23", size = 428528, upload-time = "2024-08-17T09:17:42.545Z" }, + { url = "https://files.pythonhosted.org/packages/f2/07/d9a3059f702dec5b3b703737afb6dda32f304f6e9da181a229dafd052c29/xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196", size = 194149, upload-time = "2024-08-17T09:17:44.361Z" 
}, + { url = "https://files.pythonhosted.org/packages/eb/58/27caadf78226ecf1d62dbd0c01d152ed381c14c1ee4ad01f0d460fc40eac/xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c", size = 207703, upload-time = "2024-08-17T09:17:46.656Z" }, + { url = "https://files.pythonhosted.org/packages/b1/08/32d558ce23e1e068453c39aed7b3c1cdc690c177873ec0ca3a90d5808765/xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482", size = 216255, upload-time = "2024-08-17T09:17:48.031Z" }, + { url = "https://files.pythonhosted.org/packages/3f/d4/2b971e2d2b0a61045f842b622ef11e94096cf1f12cd448b6fd426e80e0e2/xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296", size = 202744, upload-time = "2024-08-17T09:17:50.045Z" }, + { url = "https://files.pythonhosted.org/packages/19/ae/6a6438864a8c4c39915d7b65effd85392ebe22710412902487e51769146d/xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415", size = 210115, upload-time = "2024-08-17T09:17:51.834Z" }, + { url = "https://files.pythonhosted.org/packages/48/7d/b3c27c27d1fc868094d02fe4498ccce8cec9fcc591825c01d6bcb0b4fc49/xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198", size = 414247, upload-time = "2024-08-17T09:17:53.094Z" }, + { url = "https://files.pythonhosted.org/packages/a1/05/918f9e7d2fbbd334b829997045d341d6239b563c44e683b9a7ef8fe50f5d/xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442", size = 191419, upload-time = "2024-08-17T09:17:54.906Z" }, + { url = "https://files.pythonhosted.org/packages/08/29/dfe393805b2f86bfc47c290b275f0b7c189dc2f4e136fd4754f32eb18a8d/xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da", size = 30114, upload-time = "2024-08-17T09:17:56.566Z" }, + { url = "https://files.pythonhosted.org/packages/7b/d7/aa0b22c4ebb7c3ccb993d4c565132abc641cd11164f8952d89eb6a501909/xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9", size = 30003, upload-time = "2024-08-17T09:17:57.596Z" }, + { url = "https://files.pythonhosted.org/packages/69/12/f969b81541ee91b55f1ce469d7ab55079593c80d04fd01691b550e535000/xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6", size = 26773, upload-time = "2024-08-17T09:17:59.169Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/afed0f131fbda960ff15eee7f304fa0eeb2d58770fade99897984852ef23/xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1", size = 31969, upload-time = "2024-08-17T09:18:00.852Z" }, + { url = "https://files.pythonhosted.org/packages/8c/0c/7c3bc6d87e5235672fcc2fb42fd5ad79fe1033925f71bf549ee068c7d1ca/xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8", size = 30800, upload-time = "2024-08-17T09:18:01.863Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/9e/01067981d98069eec1c20201f8c145367698e9056f8bc295346e4ea32dd1/xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166", size = 221566, upload-time = "2024-08-17T09:18:03.461Z" }, + { url = "https://files.pythonhosted.org/packages/d4/09/d4996de4059c3ce5342b6e1e6a77c9d6c91acce31f6ed979891872dd162b/xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7", size = 201214, upload-time = "2024-08-17T09:18:05.616Z" }, + { url = "https://files.pythonhosted.org/packages/62/f5/6d2dc9f8d55a7ce0f5e7bfef916e67536f01b85d32a9fbf137d4cadbee38/xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623", size = 429433, upload-time = "2024-08-17T09:18:06.957Z" }, + { url = "https://files.pythonhosted.org/packages/d9/72/9256303f10e41ab004799a4aa74b80b3c5977d6383ae4550548b24bd1971/xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a", size = 194822, upload-time = "2024-08-17T09:18:08.331Z" }, + { url = "https://files.pythonhosted.org/packages/34/92/1a3a29acd08248a34b0e6a94f4e0ed9b8379a4ff471f1668e4dce7bdbaa8/xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88", size = 208538, upload-time = "2024-08-17T09:18:10.332Z" }, + { url = "https://files.pythonhosted.org/packages/53/ad/7fa1a109663366de42f724a1cdb8e796a260dbac45047bce153bc1e18abf/xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c", size = 216953, upload-time = "2024-08-17T09:18:11.707Z" }, + { url = "https://files.pythonhosted.org/packages/35/02/137300e24203bf2b2a49b48ce898ecce6fd01789c0fcd9c686c0a002d129/xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2", size = 203594, upload-time = "2024-08-17T09:18:13.799Z" }, + { url = "https://files.pythonhosted.org/packages/23/03/aeceb273933d7eee248c4322b98b8e971f06cc3880e5f7602c94e5578af5/xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084", size = 210971, upload-time = "2024-08-17T09:18:15.824Z" }, + { url = "https://files.pythonhosted.org/packages/e3/64/ed82ec09489474cbb35c716b189ddc1521d8b3de12b1b5ab41ce7f70253c/xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d", size = 415050, upload-time = "2024-08-17T09:18:17.142Z" }, + { url = "https://files.pythonhosted.org/packages/71/43/6db4c02dcb488ad4e03bc86d70506c3d40a384ee73c9b5c93338eb1f3c23/xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839", size = 192216, upload-time = "2024-08-17T09:18:18.779Z" }, + { url = "https://files.pythonhosted.org/packages/22/6d/db4abec29e7a567455344433d095fdb39c97db6955bb4a2c432e486b4d28/xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da", size = 30120, upload-time = "2024-08-17T09:18:20.009Z" }, + { url = 
"https://files.pythonhosted.org/packages/52/1c/fa3b61c0cf03e1da4767213672efe186b1dfa4fc901a4a694fb184a513d1/xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58", size = 30003, upload-time = "2024-08-17T09:18:21.052Z" }, + { url = "https://files.pythonhosted.org/packages/6b/8e/9e6fc572acf6e1cc7ccb01973c213f895cb8668a9d4c2b58a99350da14b7/xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3", size = 26777, upload-time = "2024-08-17T09:18:22.809Z" }, + { url = "https://files.pythonhosted.org/packages/07/0e/1bfce2502c57d7e2e787600b31c83535af83746885aa1a5f153d8c8059d6/xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00", size = 31969, upload-time = "2024-08-17T09:18:24.025Z" }, + { url = "https://files.pythonhosted.org/packages/3f/d6/8ca450d6fe5b71ce521b4e5db69622383d039e2b253e9b2f24f93265b52c/xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9", size = 30787, upload-time = "2024-08-17T09:18:25.318Z" }, + { url = "https://files.pythonhosted.org/packages/5b/84/de7c89bc6ef63d750159086a6ada6416cc4349eab23f76ab870407178b93/xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84", size = 220959, upload-time = "2024-08-17T09:18:26.518Z" }, + { url = "https://files.pythonhosted.org/packages/fe/86/51258d3e8a8545ff26468c977101964c14d56a8a37f5835bc0082426c672/xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793", size = 200006, upload-time = "2024-08-17T09:18:27.905Z" }, + { url = "https://files.pythonhosted.org/packages/02/0a/96973bd325412feccf23cf3680fd2246aebf4b789122f938d5557c54a6b2/xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be", size = 428326, upload-time = "2024-08-17T09:18:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/11/a7/81dba5010f7e733de88af9555725146fc133be97ce36533867f4c7e75066/xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6", size = 194380, upload-time = "2024-08-17T09:18:30.706Z" }, + { url = "https://files.pythonhosted.org/packages/fb/7d/f29006ab398a173f4501c0e4977ba288f1c621d878ec217b4ff516810c04/xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90", size = 207934, upload-time = "2024-08-17T09:18:32.133Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6e/6e88b8f24612510e73d4d70d9b0c7dff62a2e78451b9f0d042a5462c8d03/xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27", size = 216301, upload-time = "2024-08-17T09:18:33.474Z" }, + { url = "https://files.pythonhosted.org/packages/af/51/7862f4fa4b75a25c3b4163c8a873f070532fe5f2d3f9b3fc869c8337a398/xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2", size = 203351, upload-time = "2024-08-17T09:18:34.889Z" }, + { url = 
"https://files.pythonhosted.org/packages/22/61/8d6a40f288f791cf79ed5bb113159abf0c81d6efb86e734334f698eb4c59/xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d", size = 210294, upload-time = "2024-08-17T09:18:36.355Z" }, + { url = "https://files.pythonhosted.org/packages/17/02/215c4698955762d45a8158117190261b2dbefe9ae7e5b906768c09d8bc74/xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab", size = 414674, upload-time = "2024-08-17T09:18:38.536Z" }, + { url = "https://files.pythonhosted.org/packages/31/5c/b7a8db8a3237cff3d535261325d95de509f6a8ae439a5a7a4ffcff478189/xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e", size = 192022, upload-time = "2024-08-17T09:18:40.138Z" }, + { url = "https://files.pythonhosted.org/packages/78/e3/dd76659b2811b3fd06892a8beb850e1996b63e9235af5a86ea348f053e9e/xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8", size = 30170, upload-time = "2024-08-17T09:18:42.163Z" }, + { url = "https://files.pythonhosted.org/packages/d9/6b/1c443fe6cfeb4ad1dcf231cdec96eb94fb43d6498b4469ed8b51f8b59a37/xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e", size = 30040, upload-time = "2024-08-17T09:18:43.699Z" }, + { url = "https://files.pythonhosted.org/packages/0f/eb/04405305f290173acc0350eba6d2f1a794b57925df0398861a20fbafa415/xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2", size = 26796, upload-time = "2024-08-17T09:18:45.29Z" }, + { url = "https://files.pythonhosted.org/packages/c9/b8/e4b3ad92d249be5c83fa72916c9091b0965cb0faeff05d9a0a3870ae6bff/xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6", size = 31795, upload-time = "2024-08-17T09:18:46.813Z" }, + { url = "https://files.pythonhosted.org/packages/fc/d8/b3627a0aebfbfa4c12a41e22af3742cf08c8ea84f5cc3367b5de2d039cce/xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5", size = 30792, upload-time = "2024-08-17T09:18:47.862Z" }, + { url = "https://files.pythonhosted.org/packages/c3/cc/762312960691da989c7cd0545cb120ba2a4148741c6ba458aa723c00a3f8/xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc", size = 220950, upload-time = "2024-08-17T09:18:49.06Z" }, + { url = "https://files.pythonhosted.org/packages/fe/e9/cc266f1042c3c13750e86a535496b58beb12bf8c50a915c336136f6168dc/xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3", size = 199980, upload-time = "2024-08-17T09:18:50.445Z" }, + { url = "https://files.pythonhosted.org/packages/bf/85/a836cd0dc5cc20376de26b346858d0ac9656f8f730998ca4324921a010b9/xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c", size = 428324, upload-time = "2024-08-17T09:18:51.988Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/0e/15c243775342ce840b9ba34aceace06a1148fa1630cd8ca269e3223987f5/xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb", size = 194370, upload-time = "2024-08-17T09:18:54.164Z" }, + { url = "https://files.pythonhosted.org/packages/87/a1/b028bb02636dfdc190da01951d0703b3d904301ed0ef6094d948983bef0e/xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f", size = 207911, upload-time = "2024-08-17T09:18:55.509Z" }, + { url = "https://files.pythonhosted.org/packages/80/d5/73c73b03fc0ac73dacf069fdf6036c9abad82de0a47549e9912c955ab449/xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7", size = 216352, upload-time = "2024-08-17T09:18:57.073Z" }, + { url = "https://files.pythonhosted.org/packages/b6/2a/5043dba5ddbe35b4fe6ea0a111280ad9c3d4ba477dd0f2d1fe1129bda9d0/xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326", size = 203410, upload-time = "2024-08-17T09:18:58.54Z" }, + { url = "https://files.pythonhosted.org/packages/a2/b2/9a8ded888b7b190aed75b484eb5c853ddd48aa2896e7b59bbfbce442f0a1/xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf", size = 210322, upload-time = "2024-08-17T09:18:59.943Z" }, + { url = "https://files.pythonhosted.org/packages/98/62/440083fafbc917bf3e4b67c2ade621920dd905517e85631c10aac955c1d2/xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7", size = 414725, upload-time = "2024-08-17T09:19:01.332Z" }, + { url = "https://files.pythonhosted.org/packages/75/db/009206f7076ad60a517e016bb0058381d96a007ce3f79fa91d3010f49cc2/xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c", size = 192070, upload-time = "2024-08-17T09:19:03.007Z" }, + { url = "https://files.pythonhosted.org/packages/1f/6d/c61e0668943a034abc3a569cdc5aeae37d686d9da7e39cf2ed621d533e36/xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637", size = 30172, upload-time = "2024-08-17T09:19:04.355Z" }, + { url = "https://files.pythonhosted.org/packages/96/14/8416dce965f35e3d24722cdf79361ae154fa23e2ab730e5323aa98d7919e/xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43", size = 30041, upload-time = "2024-08-17T09:19:05.435Z" }, + { url = "https://files.pythonhosted.org/packages/27/ee/518b72faa2073f5aa8e3262408d284892cb79cf2754ba0c3a5870645ef73/xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b", size = 26801, upload-time = "2024-08-17T09:19:06.547Z" }, + { url = "https://files.pythonhosted.org/packages/ab/9a/233606bada5bd6f50b2b72c45de3d9868ad551e83893d2ac86dc7bb8553a/xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c", size = 29732, upload-time = "2024-08-17T09:20:11.175Z" }, + { url = 
"https://files.pythonhosted.org/packages/0c/67/f75276ca39e2c6604e3bee6c84e9db8a56a4973fde9bf35989787cf6e8aa/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986", size = 36214, upload-time = "2024-08-17T09:20:12.335Z" }, + { url = "https://files.pythonhosted.org/packages/0f/f8/f6c61fd794229cc3848d144f73754a0c107854372d7261419dcbbd286299/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6", size = 32020, upload-time = "2024-08-17T09:20:13.537Z" }, + { url = "https://files.pythonhosted.org/packages/79/d3/c029c99801526f859e6b38d34ab87c08993bf3dcea34b11275775001638a/xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b", size = 40515, upload-time = "2024-08-17T09:20:14.669Z" }, + { url = "https://files.pythonhosted.org/packages/62/e3/bef7b82c1997579c94de9ac5ea7626d01ae5858aa22bf4fcb38bf220cb3e/xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da", size = 30064, upload-time = "2024-08-17T09:20:15.925Z" }, +] + +[[package]] +name = "yarl" +version = "1.20.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/fb/efaa23fa4e45537b827620f04cf8f3cd658b76642205162e072703a5b963/yarl-1.20.1.tar.gz", hash = "sha256:d017a4997ee50c91fd5466cef416231bb82177b93b029906cefc542ce14c35ac", size = 186428, upload-time = "2025-06-10T00:46:09.923Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/65/7fed0d774abf47487c64be14e9223749468922817b5e8792b8a64792a1bb/yarl-1.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6032e6da6abd41e4acda34d75a816012717000fa6839f37124a47fcefc49bec4", size = 132910, upload-time = "2025-06-10T00:42:31.108Z" }, + { url = "https://files.pythonhosted.org/packages/8a/7b/988f55a52da99df9e56dc733b8e4e5a6ae2090081dc2754fc8fd34e60aa0/yarl-1.20.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2c7b34d804b8cf9b214f05015c4fee2ebe7ed05cf581e7192c06555c71f4446a", size = 90644, upload-time = "2025-06-10T00:42:33.851Z" }, + { url = "https://files.pythonhosted.org/packages/f7/de/30d98f03e95d30c7e3cc093759982d038c8833ec2451001d45ef4854edc1/yarl-1.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c869f2651cc77465f6cd01d938d91a11d9ea5d798738c1dc077f3de0b5e5fed", size = 89322, upload-time = "2025-06-10T00:42:35.688Z" }, + { url = "https://files.pythonhosted.org/packages/e0/7a/f2f314f5ebfe9200724b0b748de2186b927acb334cf964fd312eb86fc286/yarl-1.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62915e6688eb4d180d93840cda4110995ad50c459bf931b8b3775b37c264af1e", size = 323786, upload-time = "2025-06-10T00:42:37.817Z" }, + { url = "https://files.pythonhosted.org/packages/15/3f/718d26f189db96d993d14b984ce91de52e76309d0fd1d4296f34039856aa/yarl-1.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:41ebd28167bc6af8abb97fec1a399f412eec5fd61a3ccbe2305a18b84fb4ca73", size = 319627, upload-time = "2025-06-10T00:42:39.937Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/76/8fcfbf5fa2369157b9898962a4a7d96764b287b085b5b3d9ffae69cdefd1/yarl-1.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:21242b4288a6d56f04ea193adde174b7e347ac46ce6bc84989ff7c1b1ecea84e", size = 339149, upload-time = "2025-06-10T00:42:42.627Z" }, + { url = "https://files.pythonhosted.org/packages/3c/95/d7fc301cc4661785967acc04f54a4a42d5124905e27db27bb578aac49b5c/yarl-1.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bea21cdae6c7eb02ba02a475f37463abfe0a01f5d7200121b03e605d6a0439f8", size = 333327, upload-time = "2025-06-10T00:42:44.842Z" }, + { url = "https://files.pythonhosted.org/packages/65/94/e21269718349582eee81efc5c1c08ee71c816bfc1585b77d0ec3f58089eb/yarl-1.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f8a891e4a22a89f5dde7862994485e19db246b70bb288d3ce73a34422e55b23", size = 326054, upload-time = "2025-06-10T00:42:47.149Z" }, + { url = "https://files.pythonhosted.org/packages/32/ae/8616d1f07853704523519f6131d21f092e567c5af93de7e3e94b38d7f065/yarl-1.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd803820d44c8853a109a34e3660e5a61beae12970da479cf44aa2954019bf70", size = 315035, upload-time = "2025-06-10T00:42:48.852Z" }, + { url = "https://files.pythonhosted.org/packages/48/aa/0ace06280861ef055855333707db5e49c6e3a08840a7ce62682259d0a6c0/yarl-1.20.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b982fa7f74c80d5c0c7b5b38f908971e513380a10fecea528091405f519b9ebb", size = 338962, upload-time = "2025-06-10T00:42:51.024Z" }, + { url = "https://files.pythonhosted.org/packages/20/52/1e9d0e6916f45a8fb50e6844f01cb34692455f1acd548606cbda8134cd1e/yarl-1.20.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:33f29ecfe0330c570d997bcf1afd304377f2e48f61447f37e846a6058a4d33b2", size = 335399, upload-time = "2025-06-10T00:42:53.007Z" }, + { url = "https://files.pythonhosted.org/packages/f2/65/60452df742952c630e82f394cd409de10610481d9043aa14c61bf846b7b1/yarl-1.20.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:835ab2cfc74d5eb4a6a528c57f05688099da41cf4957cf08cad38647e4a83b30", size = 338649, upload-time = "2025-06-10T00:42:54.964Z" }, + { url = "https://files.pythonhosted.org/packages/7b/f5/6cd4ff38dcde57a70f23719a838665ee17079640c77087404c3d34da6727/yarl-1.20.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:46b5e0ccf1943a9a6e766b2c2b8c732c55b34e28be57d8daa2b3c1d1d4009309", size = 358563, upload-time = "2025-06-10T00:42:57.28Z" }, + { url = "https://files.pythonhosted.org/packages/d1/90/c42eefd79d0d8222cb3227bdd51b640c0c1d0aa33fe4cc86c36eccba77d3/yarl-1.20.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:df47c55f7d74127d1b11251fe6397d84afdde0d53b90bedb46a23c0e534f9d24", size = 357609, upload-time = "2025-06-10T00:42:59.055Z" }, + { url = "https://files.pythonhosted.org/packages/03/c8/cea6b232cb4617514232e0f8a718153a95b5d82b5290711b201545825532/yarl-1.20.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76d12524d05841276b0e22573f28d5fbcb67589836772ae9244d90dd7d66aa13", size = 350224, upload-time = "2025-06-10T00:43:01.248Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a3/eaa0ab9712f1f3d01faf43cf6f1f7210ce4ea4a7e9b28b489a2261ca8db9/yarl-1.20.1-cp310-cp310-win32.whl", hash = "sha256:6c4fbf6b02d70e512d7ade4b1f998f237137f1417ab07ec06358ea04f69134f8", size = 81753, upload-time = "2025-06-10T00:43:03.486Z" }, + { url = 
"https://files.pythonhosted.org/packages/8f/34/e4abde70a9256465fe31c88ed02c3f8502b7b5dead693a4f350a06413f28/yarl-1.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:aef6c4d69554d44b7f9d923245f8ad9a707d971e6209d51279196d8e8fe1ae16", size = 86817, upload-time = "2025-06-10T00:43:05.231Z" }, + { url = "https://files.pythonhosted.org/packages/b1/18/893b50efc2350e47a874c5c2d67e55a0ea5df91186b2a6f5ac52eff887cd/yarl-1.20.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:47ee6188fea634bdfaeb2cc420f5b3b17332e6225ce88149a17c413c77ff269e", size = 133833, upload-time = "2025-06-10T00:43:07.393Z" }, + { url = "https://files.pythonhosted.org/packages/89/ed/b8773448030e6fc47fa797f099ab9eab151a43a25717f9ac043844ad5ea3/yarl-1.20.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0f6500f69e8402d513e5eedb77a4e1818691e8f45e6b687147963514d84b44b", size = 91070, upload-time = "2025-06-10T00:43:09.538Z" }, + { url = "https://files.pythonhosted.org/packages/e3/e3/409bd17b1e42619bf69f60e4f031ce1ccb29bd7380117a55529e76933464/yarl-1.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7a8900a42fcdaad568de58887c7b2f602962356908eedb7628eaf6021a6e435b", size = 89818, upload-time = "2025-06-10T00:43:11.575Z" }, + { url = "https://files.pythonhosted.org/packages/f8/77/64d8431a4d77c856eb2d82aa3de2ad6741365245a29b3a9543cd598ed8c5/yarl-1.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bad6d131fda8ef508b36be3ece16d0902e80b88ea7200f030a0f6c11d9e508d4", size = 347003, upload-time = "2025-06-10T00:43:14.088Z" }, + { url = "https://files.pythonhosted.org/packages/8d/d2/0c7e4def093dcef0bd9fa22d4d24b023788b0a33b8d0088b51aa51e21e99/yarl-1.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:df018d92fe22aaebb679a7f89fe0c0f368ec497e3dda6cb81a567610f04501f1", size = 336537, upload-time = "2025-06-10T00:43:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/f0/f3/fc514f4b2cf02cb59d10cbfe228691d25929ce8f72a38db07d3febc3f706/yarl-1.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f969afbb0a9b63c18d0feecf0db09d164b7a44a053e78a7d05f5df163e43833", size = 362358, upload-time = "2025-06-10T00:43:18.704Z" }, + { url = "https://files.pythonhosted.org/packages/ea/6d/a313ac8d8391381ff9006ac05f1d4331cee3b1efaa833a53d12253733255/yarl-1.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:812303eb4aa98e302886ccda58d6b099e3576b1b9276161469c25803a8db277d", size = 357362, upload-time = "2025-06-10T00:43:20.888Z" }, + { url = "https://files.pythonhosted.org/packages/00/70/8f78a95d6935a70263d46caa3dd18e1f223cf2f2ff2037baa01a22bc5b22/yarl-1.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c4a7d166635147924aa0bf9bfe8d8abad6fffa6102de9c99ea04a1376f91e8", size = 348979, upload-time = "2025-06-10T00:43:23.169Z" }, + { url = "https://files.pythonhosted.org/packages/cb/05/42773027968968f4f15143553970ee36ead27038d627f457cc44bbbeecf3/yarl-1.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12e768f966538e81e6e7550f9086a6236b16e26cd964cf4df35349970f3551cf", size = 337274, upload-time = "2025-06-10T00:43:27.111Z" }, + { url = "https://files.pythonhosted.org/packages/05/be/665634aa196954156741ea591d2f946f1b78ceee8bb8f28488bf28c0dd62/yarl-1.20.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe41919b9d899661c5c28a8b4b0acf704510b88f27f0934ac7a7bebdd8938d5e", size = 363294, upload-time = 
"2025-06-10T00:43:28.96Z" }, + { url = "https://files.pythonhosted.org/packages/eb/90/73448401d36fa4e210ece5579895731f190d5119c4b66b43b52182e88cd5/yarl-1.20.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8601bc010d1d7780592f3fc1bdc6c72e2b6466ea34569778422943e1a1f3c389", size = 358169, upload-time = "2025-06-10T00:43:30.701Z" }, + { url = "https://files.pythonhosted.org/packages/c3/b0/fce922d46dc1eb43c811f1889f7daa6001b27a4005587e94878570300881/yarl-1.20.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:daadbdc1f2a9033a2399c42646fbd46da7992e868a5fe9513860122d7fe7a73f", size = 362776, upload-time = "2025-06-10T00:43:32.51Z" }, + { url = "https://files.pythonhosted.org/packages/f1/0d/b172628fce039dae8977fd22caeff3eeebffd52e86060413f5673767c427/yarl-1.20.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:03aa1e041727cb438ca762628109ef1333498b122e4c76dd858d186a37cec845", size = 381341, upload-time = "2025-06-10T00:43:34.543Z" }, + { url = "https://files.pythonhosted.org/packages/6b/9b/5b886d7671f4580209e855974fe1cecec409aa4a89ea58b8f0560dc529b1/yarl-1.20.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:642980ef5e0fa1de5fa96d905c7e00cb2c47cb468bfcac5a18c58e27dbf8d8d1", size = 379988, upload-time = "2025-06-10T00:43:36.489Z" }, + { url = "https://files.pythonhosted.org/packages/73/be/75ef5fd0fcd8f083a5d13f78fd3f009528132a1f2a1d7c925c39fa20aa79/yarl-1.20.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:86971e2795584fe8c002356d3b97ef6c61862720eeff03db2a7c86b678d85b3e", size = 371113, upload-time = "2025-06-10T00:43:38.592Z" }, + { url = "https://files.pythonhosted.org/packages/50/4f/62faab3b479dfdcb741fe9e3f0323e2a7d5cd1ab2edc73221d57ad4834b2/yarl-1.20.1-cp311-cp311-win32.whl", hash = "sha256:597f40615b8d25812f14562699e287f0dcc035d25eb74da72cae043bb884d773", size = 81485, upload-time = "2025-06-10T00:43:41.038Z" }, + { url = "https://files.pythonhosted.org/packages/f0/09/d9c7942f8f05c32ec72cd5c8e041c8b29b5807328b68b4801ff2511d4d5e/yarl-1.20.1-cp311-cp311-win_amd64.whl", hash = "sha256:26ef53a9e726e61e9cd1cda6b478f17e350fb5800b4bd1cd9fe81c4d91cfeb2e", size = 86686, upload-time = "2025-06-10T00:43:42.692Z" }, + { url = "https://files.pythonhosted.org/packages/5f/9a/cb7fad7d73c69f296eda6815e4a2c7ed53fc70c2f136479a91c8e5fbdb6d/yarl-1.20.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdcc4cd244e58593a4379fe60fdee5ac0331f8eb70320a24d591a3be197b94a9", size = 133667, upload-time = "2025-06-10T00:43:44.369Z" }, + { url = "https://files.pythonhosted.org/packages/67/38/688577a1cb1e656e3971fb66a3492501c5a5df56d99722e57c98249e5b8a/yarl-1.20.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b29a2c385a5f5b9c7d9347e5812b6f7ab267193c62d282a540b4fc528c8a9d2a", size = 91025, upload-time = "2025-06-10T00:43:46.295Z" }, + { url = "https://files.pythonhosted.org/packages/50/ec/72991ae51febeb11a42813fc259f0d4c8e0507f2b74b5514618d8b640365/yarl-1.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1112ae8154186dfe2de4732197f59c05a83dc814849a5ced892b708033f40dc2", size = 89709, upload-time = "2025-06-10T00:43:48.22Z" }, + { url = "https://files.pythonhosted.org/packages/99/da/4d798025490e89426e9f976702e5f9482005c548c579bdae792a4c37769e/yarl-1.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90bbd29c4fe234233f7fa2b9b121fb63c321830e5d05b45153a2ca68f7d310ee", size = 352287, upload-time = "2025-06-10T00:43:49.924Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/26/54a15c6a567aac1c61b18aa0f4b8aa2e285a52d547d1be8bf48abe2b3991/yarl-1.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:680e19c7ce3710ac4cd964e90dad99bf9b5029372ba0c7cbfcd55e54d90ea819", size = 345429, upload-time = "2025-06-10T00:43:51.7Z" }, + { url = "https://files.pythonhosted.org/packages/d6/95/9dcf2386cb875b234353b93ec43e40219e14900e046bf6ac118f94b1e353/yarl-1.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a979218c1fdb4246a05efc2cc23859d47c89af463a90b99b7c56094daf25a16", size = 365429, upload-time = "2025-06-10T00:43:53.494Z" }, + { url = "https://files.pythonhosted.org/packages/91/b2/33a8750f6a4bc224242a635f5f2cff6d6ad5ba651f6edcccf721992c21a0/yarl-1.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255b468adf57b4a7b65d8aad5b5138dce6a0752c139965711bdcb81bc370e1b6", size = 363862, upload-time = "2025-06-10T00:43:55.766Z" }, + { url = "https://files.pythonhosted.org/packages/98/28/3ab7acc5b51f4434b181b0cee8f1f4b77a65919700a355fb3617f9488874/yarl-1.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a97d67108e79cfe22e2b430d80d7571ae57d19f17cda8bb967057ca8a7bf5bfd", size = 355616, upload-time = "2025-06-10T00:43:58.056Z" }, + { url = "https://files.pythonhosted.org/packages/36/a3/f666894aa947a371724ec7cd2e5daa78ee8a777b21509b4252dd7bd15e29/yarl-1.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8570d998db4ddbfb9a590b185a0a33dbf8aafb831d07a5257b4ec9948df9cb0a", size = 339954, upload-time = "2025-06-10T00:43:59.773Z" }, + { url = "https://files.pythonhosted.org/packages/f1/81/5f466427e09773c04219d3450d7a1256138a010b6c9f0af2d48565e9ad13/yarl-1.20.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97c75596019baae7c71ccf1d8cc4738bc08134060d0adfcbe5642f778d1dca38", size = 365575, upload-time = "2025-06-10T00:44:02.051Z" }, + { url = "https://files.pythonhosted.org/packages/2e/e3/e4b0ad8403e97e6c9972dd587388940a032f030ebec196ab81a3b8e94d31/yarl-1.20.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:1c48912653e63aef91ff988c5432832692ac5a1d8f0fb8a33091520b5bbe19ef", size = 365061, upload-time = "2025-06-10T00:44:04.196Z" }, + { url = "https://files.pythonhosted.org/packages/ac/99/b8a142e79eb86c926f9f06452eb13ecb1bb5713bd01dc0038faf5452e544/yarl-1.20.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4c3ae28f3ae1563c50f3d37f064ddb1511ecc1d5584e88c6b7c63cf7702a6d5f", size = 364142, upload-time = "2025-06-10T00:44:06.527Z" }, + { url = "https://files.pythonhosted.org/packages/34/f2/08ed34a4a506d82a1a3e5bab99ccd930a040f9b6449e9fd050320e45845c/yarl-1.20.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c5e9642f27036283550f5f57dc6156c51084b458570b9d0d96100c8bebb186a8", size = 381894, upload-time = "2025-06-10T00:44:08.379Z" }, + { url = "https://files.pythonhosted.org/packages/92/f8/9a3fbf0968eac704f681726eff595dce9b49c8a25cd92bf83df209668285/yarl-1.20.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2c26b0c49220d5799f7b22c6838409ee9bc58ee5c95361a4d7831f03cc225b5a", size = 383378, upload-time = "2025-06-10T00:44:10.51Z" }, + { url = "https://files.pythonhosted.org/packages/af/85/9363f77bdfa1e4d690957cd39d192c4cacd1c58965df0470a4905253b54f/yarl-1.20.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564ab3d517e3d01c408c67f2e5247aad4019dcf1969982aba3974b4093279004", size = 374069, upload-time = 
"2025-06-10T00:44:12.834Z" }, + { url = "https://files.pythonhosted.org/packages/35/99/9918c8739ba271dcd935400cff8b32e3cd319eaf02fcd023d5dcd487a7c8/yarl-1.20.1-cp312-cp312-win32.whl", hash = "sha256:daea0d313868da1cf2fac6b2d3a25c6e3a9e879483244be38c8e6a41f1d876a5", size = 81249, upload-time = "2025-06-10T00:44:14.731Z" }, + { url = "https://files.pythonhosted.org/packages/eb/83/5d9092950565481b413b31a23e75dd3418ff0a277d6e0abf3729d4d1ce25/yarl-1.20.1-cp312-cp312-win_amd64.whl", hash = "sha256:48ea7d7f9be0487339828a4de0360d7ce0efc06524a48e1810f945c45b813698", size = 86710, upload-time = "2025-06-10T00:44:16.716Z" }, + { url = "https://files.pythonhosted.org/packages/8a/e1/2411b6d7f769a07687acee88a062af5833cf1966b7266f3d8dfb3d3dc7d3/yarl-1.20.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0b5ff0fbb7c9f1b1b5ab53330acbfc5247893069e7716840c8e7d5bb7355038a", size = 131811, upload-time = "2025-06-10T00:44:18.933Z" }, + { url = "https://files.pythonhosted.org/packages/b2/27/584394e1cb76fb771371770eccad35de400e7b434ce3142c2dd27392c968/yarl-1.20.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:14f326acd845c2b2e2eb38fb1346c94f7f3b01a4f5c788f8144f9b630bfff9a3", size = 90078, upload-time = "2025-06-10T00:44:20.635Z" }, + { url = "https://files.pythonhosted.org/packages/bf/9a/3246ae92d4049099f52d9b0fe3486e3b500e29b7ea872d0f152966fc209d/yarl-1.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f60e4ad5db23f0b96e49c018596707c3ae89f5d0bd97f0ad3684bcbad899f1e7", size = 88748, upload-time = "2025-06-10T00:44:22.34Z" }, + { url = "https://files.pythonhosted.org/packages/a3/25/35afe384e31115a1a801fbcf84012d7a066d89035befae7c5d4284df1e03/yarl-1.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49bdd1b8e00ce57e68ba51916e4bb04461746e794e7c4d4bbc42ba2f18297691", size = 349595, upload-time = "2025-06-10T00:44:24.314Z" }, + { url = "https://files.pythonhosted.org/packages/28/2d/8aca6cb2cabc8f12efcb82749b9cefecbccfc7b0384e56cd71058ccee433/yarl-1.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:66252d780b45189975abfed839616e8fd2dbacbdc262105ad7742c6ae58f3e31", size = 342616, upload-time = "2025-06-10T00:44:26.167Z" }, + { url = "https://files.pythonhosted.org/packages/0b/e9/1312633d16b31acf0098d30440ca855e3492d66623dafb8e25b03d00c3da/yarl-1.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59174e7332f5d153d8f7452a102b103e2e74035ad085f404df2e40e663a22b28", size = 361324, upload-time = "2025-06-10T00:44:27.915Z" }, + { url = "https://files.pythonhosted.org/packages/bc/a0/688cc99463f12f7669eec7c8acc71ef56a1521b99eab7cd3abb75af887b0/yarl-1.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e3968ec7d92a0c0f9ac34d5ecfd03869ec0cab0697c91a45db3fbbd95fe1b653", size = 359676, upload-time = "2025-06-10T00:44:30.041Z" }, + { url = "https://files.pythonhosted.org/packages/af/44/46407d7f7a56e9a85a4c207724c9f2c545c060380718eea9088f222ba697/yarl-1.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a4fbb50e14396ba3d375f68bfe02215d8e7bc3ec49da8341fe3157f59d2ff5", size = 352614, upload-time = "2025-06-10T00:44:32.171Z" }, + { url = "https://files.pythonhosted.org/packages/b1/91/31163295e82b8d5485d31d9cf7754d973d41915cadce070491778d9c9825/yarl-1.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11a62c839c3a8eac2410e951301309426f368388ff2f33799052787035793b02", size = 336766, 
upload-time = "2025-06-10T00:44:34.494Z" }, + { url = "https://files.pythonhosted.org/packages/b4/8e/c41a5bc482121f51c083c4c2bcd16b9e01e1cf8729e380273a952513a21f/yarl-1.20.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:041eaa14f73ff5a8986b4388ac6bb43a77f2ea09bf1913df7a35d4646db69e53", size = 364615, upload-time = "2025-06-10T00:44:36.856Z" }, + { url = "https://files.pythonhosted.org/packages/e3/5b/61a3b054238d33d70ea06ebba7e58597891b71c699e247df35cc984ab393/yarl-1.20.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:377fae2fef158e8fd9d60b4c8751387b8d1fb121d3d0b8e9b0be07d1b41e83dc", size = 360982, upload-time = "2025-06-10T00:44:39.141Z" }, + { url = "https://files.pythonhosted.org/packages/df/a3/6a72fb83f8d478cb201d14927bc8040af901811a88e0ff2da7842dd0ed19/yarl-1.20.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1c92f4390e407513f619d49319023664643d3339bd5e5a56a3bebe01bc67ec04", size = 369792, upload-time = "2025-06-10T00:44:40.934Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/4cc3c36dfc7c077f8dedb561eb21f69e1e9f2456b91b593882b0b18c19dc/yarl-1.20.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d25ddcf954df1754ab0f86bb696af765c5bfaba39b74095f27eececa049ef9a4", size = 382049, upload-time = "2025-06-10T00:44:42.854Z" }, + { url = "https://files.pythonhosted.org/packages/19/3a/e54e2c4752160115183a66dc9ee75a153f81f3ab2ba4bf79c3c53b33de34/yarl-1.20.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:909313577e9619dcff8c31a0ea2aa0a2a828341d92673015456b3ae492e7317b", size = 384774, upload-time = "2025-06-10T00:44:45.275Z" }, + { url = "https://files.pythonhosted.org/packages/9c/20/200ae86dabfca89060ec6447649f219b4cbd94531e425e50d57e5f5ac330/yarl-1.20.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:793fd0580cb9664548c6b83c63b43c477212c0260891ddf86809e1c06c8b08f1", size = 374252, upload-time = "2025-06-10T00:44:47.31Z" }, + { url = "https://files.pythonhosted.org/packages/83/75/11ee332f2f516b3d094e89448da73d557687f7d137d5a0f48c40ff211487/yarl-1.20.1-cp313-cp313-win32.whl", hash = "sha256:468f6e40285de5a5b3c44981ca3a319a4b208ccc07d526b20b12aeedcfa654b7", size = 81198, upload-time = "2025-06-10T00:44:49.164Z" }, + { url = "https://files.pythonhosted.org/packages/ba/ba/39b1ecbf51620b40ab402b0fc817f0ff750f6d92712b44689c2c215be89d/yarl-1.20.1-cp313-cp313-win_amd64.whl", hash = "sha256:495b4ef2fea40596bfc0affe3837411d6aa3371abcf31aac0ccc4bdd64d4ef5c", size = 86346, upload-time = "2025-06-10T00:44:51.182Z" }, + { url = "https://files.pythonhosted.org/packages/43/c7/669c52519dca4c95153c8ad96dd123c79f354a376346b198f438e56ffeb4/yarl-1.20.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:f60233b98423aab21d249a30eb27c389c14929f47be8430efa7dbd91493a729d", size = 138826, upload-time = "2025-06-10T00:44:52.883Z" }, + { url = "https://files.pythonhosted.org/packages/6a/42/fc0053719b44f6ad04a75d7f05e0e9674d45ef62f2d9ad2c1163e5c05827/yarl-1.20.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6f3eff4cc3f03d650d8755c6eefc844edde99d641d0dcf4da3ab27141a5f8ddf", size = 93217, upload-time = "2025-06-10T00:44:54.658Z" }, + { url = "https://files.pythonhosted.org/packages/4f/7f/fa59c4c27e2a076bba0d959386e26eba77eb52ea4a0aac48e3515c186b4c/yarl-1.20.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:69ff8439d8ba832d6bed88af2c2b3445977eba9a4588b787b32945871c2444e3", size = 92700, upload-time = "2025-06-10T00:44:56.784Z" }, + { url = 
"https://files.pythonhosted.org/packages/2f/d4/062b2f48e7c93481e88eff97a6312dca15ea200e959f23e96d8ab898c5b8/yarl-1.20.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cf34efa60eb81dd2645a2e13e00bb98b76c35ab5061a3989c7a70f78c85006d", size = 347644, upload-time = "2025-06-10T00:44:59.071Z" }, + { url = "https://files.pythonhosted.org/packages/89/47/78b7f40d13c8f62b499cc702fdf69e090455518ae544c00a3bf4afc9fc77/yarl-1.20.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:8e0fe9364ad0fddab2688ce72cb7a8e61ea42eff3c7caeeb83874a5d479c896c", size = 323452, upload-time = "2025-06-10T00:45:01.605Z" }, + { url = "https://files.pythonhosted.org/packages/eb/2b/490d3b2dc66f52987d4ee0d3090a147ea67732ce6b4d61e362c1846d0d32/yarl-1.20.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f64fbf81878ba914562c672024089e3401974a39767747691c65080a67b18c1", size = 346378, upload-time = "2025-06-10T00:45:03.946Z" }, + { url = "https://files.pythonhosted.org/packages/66/ad/775da9c8a94ce925d1537f939a4f17d782efef1f973039d821cbe4bcc211/yarl-1.20.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6342d643bf9a1de97e512e45e4b9560a043347e779a173250824f8b254bd5ce", size = 353261, upload-time = "2025-06-10T00:45:05.992Z" }, + { url = "https://files.pythonhosted.org/packages/4b/23/0ed0922b47a4f5c6eb9065d5ff1e459747226ddce5c6a4c111e728c9f701/yarl-1.20.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56dac5f452ed25eef0f6e3c6a066c6ab68971d96a9fb441791cad0efba6140d3", size = 335987, upload-time = "2025-06-10T00:45:08.227Z" }, + { url = "https://files.pythonhosted.org/packages/3e/49/bc728a7fe7d0e9336e2b78f0958a2d6b288ba89f25a1762407a222bf53c3/yarl-1.20.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7d7f497126d65e2cad8dc5f97d34c27b19199b6414a40cb36b52f41b79014be", size = 329361, upload-time = "2025-06-10T00:45:10.11Z" }, + { url = "https://files.pythonhosted.org/packages/93/8f/b811b9d1f617c83c907e7082a76e2b92b655400e61730cd61a1f67178393/yarl-1.20.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:67e708dfb8e78d8a19169818eeb5c7a80717562de9051bf2413aca8e3696bf16", size = 346460, upload-time = "2025-06-10T00:45:12.055Z" }, + { url = "https://files.pythonhosted.org/packages/70/fd/af94f04f275f95da2c3b8b5e1d49e3e79f1ed8b6ceb0f1664cbd902773ff/yarl-1.20.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:595c07bc79af2494365cc96ddeb772f76272364ef7c80fb892ef9d0649586513", size = 334486, upload-time = "2025-06-10T00:45:13.995Z" }, + { url = "https://files.pythonhosted.org/packages/84/65/04c62e82704e7dd0a9b3f61dbaa8447f8507655fd16c51da0637b39b2910/yarl-1.20.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7bdd2f80f4a7df852ab9ab49484a4dee8030023aa536df41f2d922fd57bf023f", size = 342219, upload-time = "2025-06-10T00:45:16.479Z" }, + { url = "https://files.pythonhosted.org/packages/91/95/459ca62eb958381b342d94ab9a4b6aec1ddec1f7057c487e926f03c06d30/yarl-1.20.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:c03bfebc4ae8d862f853a9757199677ab74ec25424d0ebd68a0027e9c639a390", size = 350693, upload-time = "2025-06-10T00:45:18.399Z" }, + { url = "https://files.pythonhosted.org/packages/a6/00/d393e82dd955ad20617abc546a8f1aee40534d599ff555ea053d0ec9bf03/yarl-1.20.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:344d1103e9c1523f32a5ed704d576172d2cabed3122ea90b1d4e11fe17c66458", size = 355803, 
upload-time = "2025-06-10T00:45:20.677Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ed/c5fb04869b99b717985e244fd93029c7a8e8febdfcffa06093e32d7d44e7/yarl-1.20.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:88cab98aa4e13e1ade8c141daeedd300a4603b7132819c484841bb7af3edce9e", size = 341709, upload-time = "2025-06-10T00:45:23.221Z" }, + { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, + { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, + { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, +]