initial commit
This commit is contained in:
106
scripts/tok_train.py
Normal file
106
scripts/tok_train.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Train a tokenizer using the HuggingFace Tokenizers library.
|
||||
In the style of GPT-4 tokenizer.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.tokenizer import RustBPETokenizer
|
||||
from nanochat.common import get_base_dir
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
print(f"vocab_size: {args.vocab_size:,}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Text iterator
|
||||
|
||||
def text_iterator():
|
||||
"""
|
||||
1) Flatten the batches into a single iterator
|
||||
2) Crop every document to args.doc_cap characters
|
||||
3) Break when we've seen args.max_chars characters
|
||||
"""
|
||||
nchars = 0
|
||||
for batch in parquets_iter_batched(split="train"):
|
||||
for doc in batch:
|
||||
doc_text = doc
|
||||
if len(doc_text) > args.doc_cap:
|
||||
doc_text = doc_text[:args.doc_cap]
|
||||
nchars += len(doc_text)
|
||||
yield doc_text
|
||||
if nchars > args.max_chars:
|
||||
return
|
||||
text_iter = text_iterator()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Train the tokenizer
|
||||
t0 = time.time()
|
||||
tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
|
||||
t1 = time.time()
|
||||
train_time = t1 - t0
|
||||
print(f"Training time: {train_time:.2f}s")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Save the tokenizer to disk
|
||||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
tokenizer.save(tokenizer_dir)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quick inline sanity check
|
||||
test_text = """Hello world! This is a test.
|
||||
Numbers: 123, 4567, 89
|
||||
Contractions: I'm, you're, it's
|
||||
Special chars: @#$%^&*()
|
||||
Unicode: 你好世界 🌍"""
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == test_text
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# One more thing: we wish to cache a mapping from token id to number of bytes of that token
|
||||
# for efficient evaluation of bits per byte. Unlike the typical mean loss, this
|
||||
# allows us to report a loss that is invariant to the vocab size of the tokenizer.
|
||||
# The bits per byte on the validation set is then one of the primary metrics we care about.
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
special_set = set(tokenizer.get_special_tokens())
|
||||
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
|
||||
token_bytes = []
|
||||
for token_id in range(vocab_size):
|
||||
token_str = token_strings[token_id] # the Python string representation of this token
|
||||
if token_str in special_set:
|
||||
token_bytes.append(0) # special characters are not counted
|
||||
else:
|
||||
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
|
||||
token_bytes.append(id_bytes)
|
||||
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
|
||||
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
|
||||
with open(token_bytes_path, "wb") as f:
|
||||
torch.save(token_bytes, f)
|
||||
print(f"Saved token_bytes to {token_bytes_path}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
|
||||
get_report().log(section="Tokenizer training", data=[
|
||||
vars(args), # argparse command line arguments
|
||||
{"train_time": train_time},
|
||||
{"num_special_tokens": len(special_set)},
|
||||
{
|
||||
"token_bytes_min": int(token_bytes_nonzero.min().item()),
|
||||
"token_bytes_max": int(token_bytes_nonzero.max().item()),
|
||||
"token_bytes_mean": token_bytes_nonzero.mean().item(),
|
||||
"token_bytes_std": token_bytes_nonzero.std().item(),
|
||||
}
|
||||
])
|
||||
Reference in New Issue
Block a user