diff --git a/scripts/base_train.py b/scripts/base_train.py index 63f00dc..ebc5ff4 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -7,8 +7,8 @@ or distributed as: torchrun --nproc_per_node=8 base_train.py -If you just want to see it run on CPU (you won't get far but it should run), try something like: -python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000 +python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --total_batch_size=512 --num_iterations=1000 +If you have a MacBook, you're better off using --device_type=mps instead of cpu """ import os