diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 562d517..3e89893 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -178,8 +178,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( doc = doc_buffer.pop(best_idx) row.extend(doc) else: - # No doc fits - crop first doc to fill remaining - doc = doc_buffer.pop(0) + # No doc fits - crop shortest in buffer to fill remaining and minimize waste + shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) + doc = doc_buffer.pop(shortest_idx) row.extend(doc[:remaining]) rows.append(row[:row_capacity])