initial commit
scripts/chat_cli.py · 99 lines · Normal file
@@ -0,0 +1,99 @@
"""
New and upgraded chat mode because a lot of the code has changed since the last one.

Intended to be run on a single GPU only atm:
python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model

parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
args = parser.parse_args()

# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
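# A conversation is kept as one flat list of token ids, laid out per turn as:
#   [bos, user_start, <user message tokens>, user_end, assistant_start, <assistant tokens>, assistant_end, ...]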

# Create Engine for efficient generation
engine = Engine(model, tokenizer)

print("\nNanoChat Interactive Mode")
print("-" * 50)
print("Type 'quit' or 'exit' to end the conversation")
print("Type 'clear' to start a new conversation")
print("-" * 50)

conversation_tokens = [bos]

while True:

    if args.prompt:
        # Get the prompt from the launch command
        user_input = args.prompt
    else:
        # Get the prompt interactively from the console
        try:
            user_input = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

    # Handle special commands
    if user_input.lower() in ['quit', 'exit']:
        print("Goodbye!")
        break

    if user_input.lower() == 'clear':
        conversation_tokens = [bos]
        print("Conversation cleared.")
        continue

    if not user_input:
        continue

    # Add User message to the conversation
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(user_input))
    conversation_tokens.append(user_end)

    # Kick off the assistant
    conversation_tokens.append(assistant_start)
    generate_kwargs = {
        "num_samples": 1,
        "max_tokens": 256,
        "temperature": args.temperature,
        "top_k": args.top_k,
    }
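    # Note: max_tokens caps the reply length, so generation can stop before the model
    # emits <|assistant_end|>; the fix-up after the generation loop handles that case.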
response_tokens = []
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
print(token_text, end="", flush=True)
|
||||
print()
|
||||
    # We have to ensure that the assistant end token is the last token, so even if
    # generation ends due to max tokens, we append it at the end ourselves.
    if response_tokens[-1] != assistant_end:
        response_tokens.append(assistant_end)
    conversation_tokens.extend(response_tokens)

    # In prompt mode we only want a single response, then exit
    if args.prompt:
        break