From 4346536ab2e57917ec543b20e88c4bdc47eda572 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 16 Oct 2025 01:28:37 +0000 Subject: [PATCH] also allow regenerating assistant message by clicking it, and make sure to feed good seed to generate --- nanochat/ui.html | 173 +++++++++++++++++++++++++++++--------------- scripts/chat_web.py | 4 +- 2 files changed, 117 insertions(+), 60 deletions(-) diff --git a/nanochat/ui.html b/nanochat/ui.html index d46eeb8..b2b4605 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -108,6 +108,15 @@ background: transparent; border: none; padding: 0.25rem 0; + cursor: pointer; + border-radius: 0.5rem; + padding: 0.5rem; + margin-left: -0.5rem; + transition: background-color 0.2s ease; + } + + .message.assistant .message-content:hover { + background-color: #f9fafb; } .message.user .message-content { @@ -325,6 +334,17 @@ }); } + // Add click handler for assistant messages to enable regeneration + if (role === 'assistant' && messageIndex !== null) { + contentDiv.setAttribute('data-message-index', messageIndex); + contentDiv.setAttribute('title', 'Click to regenerate this response'); + contentDiv.addEventListener('click', function() { + if (!isGenerating) { + regenerateMessage(messageIndex); + } + }); + } + messageDiv.appendChild(contentDiv); chatWrapper.appendChild(messageDiv); @@ -358,6 +378,99 @@ chatInput.focus(); } + async function generateAssistantResponse() { + isGenerating = true; + sendButton.disabled = true; + + const assistantContent = addMessage('assistant', ''); + assistantContent.innerHTML = ''; + + try { + const response = await fetch(`${API_URL}/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + messages: messages, + temperature: currentTemperature, + top_k: currentTopK, + max_tokens: 512 + }), + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let fullResponse = ''; + assistantContent.textContent = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + const lines = chunk.split('\n'); + + for (const line of lines) { + if (line.startsWith('data: ')) { + try { + const data = JSON.parse(line.slice(6)); + if (data.token) { + fullResponse += data.token; + assistantContent.textContent = fullResponse; + chatContainer.scrollTop = chatContainer.scrollHeight; + } + } catch (e) { + } + } + } + } + + const assistantMessageIndex = messages.length; + messages.push({ role: 'assistant', content: fullResponse }); + + // Add click handler to regenerate this assistant message + assistantContent.setAttribute('data-message-index', assistantMessageIndex); + assistantContent.setAttribute('title', 'Click to regenerate this response'); + assistantContent.addEventListener('click', function() { + if (!isGenerating) { + regenerateMessage(assistantMessageIndex); + } + }); + + } catch (error) { + console.error('Error:', error); + assistantContent.innerHTML = `
Error: ${error.message}
`; + } finally { + isGenerating = false; + sendButton.disabled = !chatInput.value.trim(); + } + } + + async function regenerateMessage(messageIndex) { + // Find the message in the messages array + if (messageIndex < 0 || messageIndex >= messages.length) return; + + const messageToRegenerate = messages[messageIndex]; + if (messageToRegenerate.role !== 'assistant') return; + + // Remove this message and all subsequent messages from the array + messages = messages.slice(0, messageIndex); + + // Remove message elements from DOM starting from messageIndex + const allMessages = chatWrapper.querySelectorAll('.message'); + for (let i = messageIndex; i < allMessages.length; i++) { + allMessages[i].remove(); + } + + // Regenerate the assistant response + await generateAssistantResponse(); + } + function handleSlashCommand(command) { const parts = command.trim().split(/\s+/); const cmd = parts[0].toLowerCase(); @@ -419,72 +532,14 @@ return; } - isGenerating = true; chatInput.value = ''; chatInput.style.height = 'auto'; - sendButton.disabled = true; const userMessageIndex = messages.length; messages.push({ role: 'user', content: message }); addMessage('user', message, userMessageIndex); - const assistantContent = addMessage('assistant', ''); - assistantContent.innerHTML = ''; - - try { - const response = await fetch(`${API_URL}/chat/completions`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - messages: messages, - temperature: currentTemperature, - top_k: currentTopK, - max_tokens: 512 - }), - }); - - if (!response.ok) { - throw new Error(`HTTP error! status: ${response.status}`); - } - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let fullResponse = ''; - assistantContent.textContent = ''; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const chunk = decoder.decode(value); - const lines = chunk.split('\n'); - - for (const line of lines) { - if (line.startsWith('data: ')) { - try { - const data = JSON.parse(line.slice(6)); - if (data.token) { - fullResponse += data.token; - assistantContent.textContent = fullResponse; - chatContainer.scrollTop = chatContainer.scrollHeight; - } - } catch (e) { - } - } - } - } - - messages.push({ role: 'assistant', content: fullResponse }); - - } catch (error) { - console.error('Error:', error); - assistantContent.innerHTML = `
Error: ${error.message}
`; - } finally { - isGenerating = false; - sendButton.disabled = !chatInput.value.trim(); - } + await generateAssistantResponse(); } sendButton.disabled = false; diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 24258a2..c07725e 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -36,6 +36,7 @@ import os import torch import asyncio import logging +import random from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -268,7 +269,8 @@ async def generate_stream( num_samples=1, max_tokens=max_new_tokens, temperature=temperature, - top_k=top_k + top_k=top_k, + seed=random.randint(0, 2**31 - 1) ): token = token_column[0]