diff --git a/tests/test_engine.py b/tests/test_engine.py
index 9351e5a..67b8a5c 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -195,3 +195,72 @@ def test_multi_sample_first_token_diversity():
         f"With uniform logits, this is statistically impossible (~10^-36 probability) "
         f"unless tokens are being broadcast instead of independently sampled."
     )
+
+
+def test_seed_reproducibility():
+    """Same seed must produce identical output."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]  # BOS + "Hello"
+
+    for seed in [1, 42, 123, 999]:
+        r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
+
+
+def test_temperature_zero_determinism():
+    """Temperature=0 is deterministic regardless of seed."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
+    r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
+    r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
+    assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
+
+
+def test_max_tokens_respected():
+    """Generation stops at the max_tokens limit."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for max_tokens in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
+        num_generated_tokens = len(results[0]) - len(prompt)
+        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
+
+
+def test_num_samples_count():
+    """num_samples=N produces exactly N sequences."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for num_samples in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
+        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
+
+
+def test_different_seeds_introduce_variation_when_temperature_nonzero():
+    """With temperature > 0, different seeds should introduce sampling variation."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]  # BOS + "Hello"
+
+    outputs = set()
+
+    for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
+        results, _ = engine.generate_batch(
+            prompt,
+            temperature=1.0,
+            max_tokens=5,
+            seed=seed,
+        )
+        outputs.add(tuple(results[0]))
+
+    # Sanity check: sampling actually introduces variation
+    assert len(outputs) > 1, "All seeds produced the same output, which is statistically highly improbable."