upgrading all other files to be able to use cpu/mps as well as cuda. various minor other changes ,e.g. changing max_iterations to num_iterations in sft script for consistency in naming
This commit is contained in:
@@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from contextlib import nullcontext
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
@@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
||||
Reference in New Issue
Block a user