upgrading all other files to be able to use cpu/mps as well as cuda. various minor other changes ,e.g. changing max_iterations to num_iterations in sft script for consistency in naming
This commit is contained in:
@@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
@@ -191,11 +192,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user