mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
Add high value engine tests for core invariants (33 LoC) (#396)
* test: add engine generation tests for expected invariants

  - test_seed_reproducibility
  - test_temperature_zero_determinism
  - test_max_tokens_respected
  - test_num_samples_count

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix temperature test

* add test for seed variation in sampling

  Add test for seed variation in sampling with temperature > 0.

* Rename test for clarity

* Shorten assert msg

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
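For context, the determinism invariants exercised below reduce to two properties of the token sampler: drawing from an RNG seeded the same way is reproducible, and temperature=0 collapses sampling to argmax, so the seed stops mattering. The following is a minimal PyTorch sketch of those two properties only; it is not nanochat's Engine, MockModel, or ByteTokenizer, and the vocab size used here is assumed purely for illustration.

# Illustrative sketch only (assumptions: plain PyTorch, made-up vocab size); not nanochat's Engine.
import torch

def sample_token(logits, temperature, generator):
    """Draw one token id from a logits vector, mirroring the invariants under test."""
    if temperature == 0.0:
        return int(torch.argmax(logits).item())  # greedy: the generator is never consulted
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1, generator=generator).item())

logits = torch.randn(262)  # assumed vocab size; any fixed logits vector works for the illustration

# Same seed -> identical draw (what test_seed_reproducibility checks end to end).
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
assert sample_token(logits, 1.0, g1) == sample_token(logits, 1.0, g2)

# Temperature 0 -> identical output regardless of seed (what test_temperature_zero_determinism checks).
g_a = torch.Generator().manual_seed(1)
g_b = torch.Generator().manual_seed(999)
assert sample_token(logits, 0.0, g_a) == sample_token(logits, 0.0, g_b)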
@@ -195,3 +195,72 @@ def test_multi_sample_first_token_diversity():
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )


def test_seed_reproducibility():
    """Same seed must produce identical output."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"

    for seed in [1, 42, 123, 999]:
        r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."


def test_temperature_zero_determinism():
    """Temperature=0 is deterministic regardless of seed."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
    r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
    r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
    assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."


def test_max_tokens_respected():
    """Generation stops at max_tokens limit."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for max_tokens in [1, 4, 16, 64]:
        results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
        num_generated_tokens = len(results[0]) - len(prompt)
        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."


def test_num_samples_count():
    """num_samples=N produces exactly N sequences."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for num_samples in [1, 4, 16, 64]:
        results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"


def test_different_seeds_introduce_variation_when_temperature_nonzero():
    """With temperature > 0, different seeds should introduce sampling variation."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"

    outputs = set()

    for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
        results, _ = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        outputs.add(tuple(results[0]))

    # Sanity check: sampling actually introduces variation
    assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
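If the module path matches (tests/test_engine.py is an assumption here; point it at wherever this suite lives), the new invariants can be run in isolation with pytest's keyword filter:

    python -m pytest tests/test_engine.py -v -k "seed or temperature or max_tokens or num_samples"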