switch to llms
@@ -13,7 +13,7 @@ export const REFERENCE_COUNT_FILE = '/tmp/claude-code-reference-count.txt';
export const DEFAULT_CONFIG = {
  log: false,
  LOG: false,
  OPENAI_API_KEY: "",
  OPENAI_BASE_URL: "",
  OPENAI_MODEL: "",

src/index.ts
@@ -1,19 +1,15 @@
import { existsSync } from "fs";
import { writeFile } from "fs/promises";
import { getOpenAICommonOptions, initConfig, initDir } from "./utils";
import { homedir } from "os";
import { join } from "path";
import { initConfig, initDir } from "./utils";
import { createServer } from "./server";
import { formatRequest } from "./middlewares/formatRequest";
import { rewriteBody } from "./middlewares/rewriteBody";
import { router } from "./middlewares/router";
import OpenAI from "openai";
import { streamOpenAIResponse } from "./utils/stream";
import { router } from "./utils/router";
import {
  cleanupPidFile,
  isServiceRunning,
  savePid,
} from "./utils/processCheck";
import { LRUCache } from "lru-cache";
import { log } from "./utils/log";

async function initializeClaudeConfig() {
  const homeDir = process.env.HOME;

@@ -39,13 +35,6 @@ interface RunOptions {
  port?: number;
}

interface ModelProvider {
  name: string;
  api_base_url: string;
  api_key: string;
  models: string[];
}

async function run(options: RunOptions = {}) {
  // Check if service is already running
  if (isServiceRunning()) {

@@ -57,51 +46,6 @@ async function run(options: RunOptions = {}) {
  await initDir();
  const config = await initConfig();

  const Providers = new Map<string, ModelProvider>();
  const providerCache = new LRUCache<string, OpenAI>({
    max: 10,
    ttl: 2 * 60 * 60 * 1000,
  });

  function getProviderInstance(providerName: string): OpenAI {
    const provider: ModelProvider | undefined = Providers.get(providerName);
    if (provider === undefined) {
      throw new Error(`Provider ${providerName} not found`);
    }
    let openai = providerCache.get(provider.name);
    if (!openai) {
      openai = new OpenAI({
        baseURL: provider.api_base_url,
        apiKey: provider.api_key,
        ...getOpenAICommonOptions(),
      });
      providerCache.set(provider.name, openai);
    }
    return openai;
  }
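
Note on the removed cache above: getProviderInstance handed back at most one OpenAI client per provider, reused across requests until the two-hour TTL lapsed or the 10-entry LRU evicted it. A minimal usage sketch of that behavior (hypothetical calls, not part of the commit):

// Repeated lookups return the same cached client instance; a new one is
// only constructed after TTL expiry or LRU eviction.
const first = getProviderInstance("default");
const second = getProviderInstance("default");
console.assert(first === second, "expected the cached client to be reused");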

  if (Array.isArray(config.Providers)) {
    config.Providers.forEach((provider) => {
      try {
        Providers.set(provider.name, provider);
      } catch (error) {
        console.error("Failed to parse model provider:", error);
      }
    });
  }

  if (config.OPENAI_API_KEY && config.OPENAI_BASE_URL && config.OPENAI_MODEL) {
    const defaultProvider = {
      name: "default",
      api_base_url: config.OPENAI_BASE_URL,
      api_key: config.OPENAI_API_KEY,
      models: [config.OPENAI_MODEL],
    };
    Providers.set("default", defaultProvider);
  } else if (Providers.size > 0) {
    const defaultProvider = Providers.values().next().value!;
    Providers.set("default", defaultProvider);
  }
  const port = options.port || 3456;

  // Save the PID of the background process

@@ -124,39 +68,16 @@ async function run(options: RunOptions = {}) {
  const servicePort = process.env.SERVICE_PORT
    ? parseInt(process.env.SERVICE_PORT)
    : port;

  const server = await createServer(servicePort);
  server.useMiddleware((req, res, next) => {
    req.config = config;
    next();
  });
  server.useMiddleware(rewriteBody);
  if (
    config.Router?.background &&
    config.Router?.think &&
    config?.Router?.longContext
  ) {
    server.useMiddleware(router);
  } else {
    server.useMiddleware((req, res, next) => {
      req.provider = "default";
      req.body.model = config.OPENAI_MODEL;
      next();
    });
  }
  server.useMiddleware(formatRequest);

  server.app.post("/v1/messages", async (req, res) => {
    try {
      const provider = getProviderInstance(req.provider || "default");
      const completion: any = await provider.chat.completions.create(req.body);
      await streamOpenAIResponse(res, completion, req.body.model, req.body);
    } catch (e) {
      log("Error in OpenAI API call:", e);
    }
  const server = createServer({
    ...config,
    providers: config.Providers || config.providers,
    PORT: servicePort,
    LOG_FILE: join(homedir(), ".claude-code-router", "claude-code-router.log"),
  });
  server.addHook("preHandler", async (req, reply) =>
    router(req, reply, config)
  );
  server.start();
  console.log(`🚀 Claude Code Router is running on port ${servicePort}`);
}

export { run };
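
For reference, a sketch of starting the router from another module via the exported run (hypothetical caller and import path; port 3456 is the default used above):

// Hypothetical entry point: starts the service unless one is already running.
import { run } from "./index";

run({ port: 3456 }).catch((err) => console.error("failed to start:", err));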

@@ -1,209 +0,0 @@
import { Request, Response, NextFunction } from "express";
import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
import OpenAI from "openai";
import { streamOpenAIResponse } from "../utils/stream";
import { log } from "../utils/log";

export const formatRequest = async (
  req: Request,
  res: Response,
  next: NextFunction
) => {
  let {
    model,
    max_tokens,
    messages,
    system = [],
    temperature,
    metadata,
    tools,
    stream,
  }: MessageCreateParamsBase = req.body;
  log("formatRequest: ", req.body);
  try {
    // @ts-ignore
    const openAIMessages = Array.isArray(messages)
      ? messages.flatMap((anthropicMessage) => {
          const openAiMessagesFromThisAnthropicMessage = [];

          if (!Array.isArray(anthropicMessage.content)) {
            // Handle simple string content
            if (typeof anthropicMessage.content === "string") {
              openAiMessagesFromThisAnthropicMessage.push({
                role: anthropicMessage.role,
                content: anthropicMessage.content,
              });
            }
            // If content is not string and not array (e.g. null/undefined), it will result in an empty array, effectively skipping this message.
            return openAiMessagesFromThisAnthropicMessage;
          }

          // Handle array content
          if (anthropicMessage.role === "assistant") {
            const assistantMessage = {
              role: "assistant",
              content: null, // Will be populated if text parts exist
            };
            let textContent = "";
            // @ts-ignore
            const toolCalls = []; // Corrected type here

            anthropicMessage.content.forEach((contentPart) => {
              if (contentPart.type === "text") {
                textContent +=
                  (typeof contentPart.text === "string"
                    ? contentPart.text
                    : JSON.stringify(contentPart.text)) + "\n";
              } else if (contentPart.type === "tool_use") {
                toolCalls.push({
                  id: contentPart.id,
                  type: "function",
                  function: {
                    name: contentPart.name,
                    arguments: JSON.stringify(contentPart.input),
                  },
                });
              }
            });

            const trimmedTextContent = textContent.trim();
            if (trimmedTextContent.length > 0) {
              // @ts-ignore
              assistantMessage.content = trimmedTextContent;
            }
            if (toolCalls.length > 0) {
              // @ts-ignore
              assistantMessage.tool_calls = toolCalls;
            }
            // @ts-ignore
            if (
              assistantMessage.content ||
              // @ts-ignore
              (assistantMessage.tool_calls &&
                // @ts-ignore
                assistantMessage.tool_calls.length > 0)
            ) {
              openAiMessagesFromThisAnthropicMessage.push(assistantMessage);
            }
          } else if (anthropicMessage.role === "user") {
            // For user messages, text parts are combined into one message.
            // Tool results are transformed into subsequent, separate 'tool' role messages.
            let userTextMessageContent = "";
            // @ts-ignore
            const subsequentToolMessages = [];

            anthropicMessage.content.forEach((contentPart) => {
              if (contentPart.type === "text") {
                userTextMessageContent +=
                  (typeof contentPart.text === "string"
                    ? contentPart.text
                    : JSON.stringify(contentPart.text)) + "\n";
              } else if (contentPart.type === "tool_result") {
                // Each tool_result becomes a separate 'tool' message
                subsequentToolMessages.push({
                  role: "tool",
                  tool_call_id: contentPart.tool_use_id,
                  content:
                    typeof contentPart.content === "string"
                      ? contentPart.content
                      : JSON.stringify(contentPart.content),
                });
              }
            });

            const trimmedUserText = userTextMessageContent.trim();
            if (trimmedUserText.length > 0) {
              openAiMessagesFromThisAnthropicMessage.push({
                role: "user",
                content: trimmedUserText,
              });
            }
            // @ts-ignore
            openAiMessagesFromThisAnthropicMessage.push(
              // @ts-ignore
              ...subsequentToolMessages
            );
          } else {
            // Fallback for other roles (e.g. system, or custom roles if they were to appear here with array content)
            // This will combine all text parts into a single message for that role.
            let combinedContent = "";
            anthropicMessage.content.forEach((contentPart) => {
              if (contentPart.type === "text") {
                combinedContent +=
                  (typeof contentPart.text === "string"
                    ? contentPart.text
                    : JSON.stringify(contentPart.text)) + "\n";
              } else {
                // For non-text parts in other roles, stringify them or handle as appropriate
                combinedContent += JSON.stringify(contentPart) + "\n";
              }
            });
            const trimmedCombinedContent = combinedContent.trim();
            if (trimmedCombinedContent.length > 0) {
              openAiMessagesFromThisAnthropicMessage.push({
                role: anthropicMessage.role, // Cast needed as role could be other than 'user'/'assistant'
                content: trimmedCombinedContent,
              });
            }
          }
          return openAiMessagesFromThisAnthropicMessage;
        })
      : [];
    const systemMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
      Array.isArray(system)
        ? system.map((item) => ({
            role: "system",
            content: item.text,
          }))
        : [{ role: "system", content: system }];
    const data: any = {
      model,
      messages: [...systemMessages, ...openAIMessages],
      temperature,
      stream,
    };
    if (tools) {
      data.tools = tools
        .filter((tool) => !["StickerRequest"].includes(tool.name))
        .map((item: any) => ({
          type: "function",
          function: {
            name: item.name,
            description: item.description,
            parameters: item.input_schema,
          },
        }));
    }
    if (stream) {
      res.setHeader("Content-Type", "text/event-stream");
    }
    res.setHeader("Cache-Control", "no-cache");
    res.setHeader("Connection", "keep-alive");
    req.body = data;
    console.log(JSON.stringify(data.messages, null, 2));
  } catch (error) {
    console.error("Error in request processing:", error);
    const errorCompletion: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> =
      {
        async *[Symbol.asyncIterator]() {
          yield {
            id: `error_${Date.now()}`,
            created: Math.floor(Date.now() / 1000),
            model,
            object: "chat.completion.chunk",
            choices: [
              {
                index: 0,
                delta: {
                  content: `Error: ${(error as Error).message}`,
                },
                finish_reason: "stop",
              },
            ],
          };
        },
      };
    await streamOpenAIResponse(res, errorCompletion, model, req.body);
  }
  next();
};
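
To make the deleted conversion logic concrete: an Anthropic user turn whose content array carried a tool_result was split off into a separate OpenAI "tool" role message, per the branch above. An illustrative input/output pair with placeholder IDs:

// Anthropic-style input message (placeholder values):
const anthropicMessage = {
  role: "user",
  content: [{ type: "tool_result", tool_use_id: "toolu_123", content: "42" }],
};
// What formatRequest produced from it:
// [{ role: "tool", tool_call_id: "toolu_123", content: "42" }]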

@@ -1,45 +0,0 @@
import { Request, Response, NextFunction } from "express";
import Module from "node:module";
import { streamOpenAIResponse } from "../utils/stream";
import { log } from "../utils/log";
import { PLUGINS_DIR } from "../constants";
import path from "node:path";
import { access } from "node:fs/promises";
import { OpenAI } from "openai";
import { createClient } from "../utils";

// @ts-ignore
const originalLoad = Module._load;
// @ts-ignore
Module._load = function (request, parent, isMain) {
  if (request === "claude-code-router") {
    return {
      streamOpenAIResponse,
      log,
      OpenAI,
      createClient,
    };
  }
  return originalLoad.call(this, request, parent, isMain);
};

export const rewriteBody = async (
  req: Request,
  res: Response,
  next: NextFunction
) => {
  if (!req.config.usePlugins) {
    return next();
  }
  for (const plugin of req.config.usePlugins) {
    const pluginPath = path.join(PLUGINS_DIR, `${plugin.trim()}.js`);
    try {
      await access(pluginPath);
      const rewritePlugin = require(pluginPath);
      await rewritePlugin(req, res);
    } catch (e) {
      console.error(e);
    }
  }
  next();
};
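
The Module._load override above existed so plugin files could require "claude-code-router" and receive the router's own helpers. A sketch of what a plugin in PLUGINS_DIR might have looked like (hypothetical file name and body; the (req, res) signature follows the rewritePlugin call above):

// ~/.claude-code-router/plugins/example.js (hypothetical)
const { log } = require("claude-code-router");

module.exports = async (req, res) => {
  log("rewriting request for model:", req.body.model);
  req.body.temperature = 0; // assumption: plugins mutate req.body in place
};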

@@ -1,23 +1,8 @@
import express, { RequestHandler } from "express";
import Server from "@musistudio/llms";

interface Server {
  app: express.Application;
  useMiddleware: (middleware: RequestHandler) => void;
  start: () => void;
}

export const createServer = async (port: number): Promise<Server> => {
  const app = express();
  app.use(express.json({ limit: "500mb" }));
  return {
    app,
    useMiddleware: (middleware: RequestHandler) => {
      app.use("/v1/messages", middleware);
    },
    start: () => {
      app.listen(port, () => {
        console.log(`Server is running on port ${port}`);
      });
    },
  };
};
export const createServer = (config: any): Server => {
  const server = new Server({
    initialConfig: config,
  });
  return server;
};
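
The new factory delegates everything to the Server class from @musistudio/llms. A minimal usage sketch, assuming only the surface this commit itself exercises in src/index.ts (addHook and start):

// Hypothetical bootstrap mirroring src/index.ts:
const server = createServer({ PORT: 3456, providers: [] });
server.addHook("preHandler", async (req, reply) => {
  // request routing / body rewriting happens here
});
server.start();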

@@ -9,13 +9,6 @@ export async function executeCodeCommand(args: string[] = []) {
  // Set environment variables
  const env = {
    ...process.env,
    HTTPS_PROXY: undefined,
    HTTP_PROXY: undefined,
    ALL_PROXY: undefined,
    https_proxy: undefined,
    http_proxy: undefined,
    all_proxy: undefined,
    DISABLE_PROMPT_CACHING: "1",
    ANTHROPIC_AUTH_TOKEN: "test",
    ANTHROPIC_BASE_URL: `http://127.0.0.1:3456`,
    API_TIMEOUT_MS: "600000",

@@ -29,7 +22,7 @@ export async function executeCodeCommand(args: string[] = []) {
  const claudeProcess = spawn(claudePath, args, {
    env,
    stdio: "inherit",
    shell: true
    shell: true,
  });

  claudeProcess.on("error", (error) => {

@@ -1,5 +1,3 @@
import { HttpsProxyAgent } from "https-proxy-agent";
import OpenAI, { ClientOptions } from "openai";
import fs from "node:fs/promises";
import readline from "node:readline";
import {

@@ -9,16 +7,6 @@ import {
  PLUGINS_DIR,
} from "../constants";

export function getOpenAICommonOptions(): ClientOptions {
  const options: ClientOptions = {};
  if (process.env.PROXY_URL) {
    options.httpAgent = new HttpsProxyAgent(process.env.PROXY_URL);
  } else if (process.env.HTTPS_PROXY) {
    options.httpAgent = new HttpsProxyAgent(process.env.HTTPS_PROXY);
  }
  return options;
}
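
Usage sketch for the removed proxy helper: any client built with these options tunneled through PROXY_URL (or HTTPS_PROXY) when set. Hypothetical placeholder values:

// With PROXY_URL set before construction, the client routes via the proxy agent.
process.env.PROXY_URL = "http://127.0.0.1:7890"; // placeholder
const client = new OpenAI({ apiKey: "sk-placeholder", ...getOpenAICommonOptions() });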

const ensureDir = async (dir_path: string) => {
  try {
    await fs.access(dir_path);

@@ -63,9 +51,17 @@ export const readConfigFile = async () => {
  const baseUrl = await question("Enter OPENAI_BASE_URL: ");
  const model = await question("Enter OPENAI_MODEL: ");
  const config = Object.assign({}, DEFAULT_CONFIG, {
    OPENAI_API_KEY: apiKey,
    OPENAI_BASE_URL: baseUrl,
    OPENAI_MODEL: model,
    Providers: [
      {
        name: "openai",
        api_base_url: baseUrl,
        api_key: apiKey,
        models: [model],
      },
    ],
    Router: {
      default: `openai,${model}`,
    },
  });
  await writeConfigFile(config);
  return config;
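
The interactive setup above now also seeds a Providers list and a Router section alongside the legacy OPENAI_* keys. A sketch of the resulting config.json shape, with placeholder values:

// Shape of the generated config (values are placeholders):
const exampleConfig = {
  OPENAI_API_KEY: "sk-placeholder",
  OPENAI_BASE_URL: "https://api.openai.com/v1",
  OPENAI_MODEL: "gpt-4o",
  Providers: [
    {
      name: "openai",
      api_base_url: "https://api.openai.com/v1",
      api_key: "sk-placeholder",
      models: ["gpt-4o"],
    },
  ],
  Router: {
    default: "openai,gpt-4o",
  },
};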

@@ -82,11 +78,3 @@ export const initConfig = async () => {
  Object.assign(process.env, config);
  return config;
};

export const createClient = (options: ClientOptions) => {
  const client = new OpenAI({
    ...options,
    ...getOpenAICommonOptions(),
  });
  return client;
};
@@ -1,57 +1,32 @@

import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages";
import { Request, Response, NextFunction } from "express";
import { get_encoding } from "tiktoken";
import { log } from "../utils/log";
import { log } from "./log";

const enc = get_encoding("cl100k_base");

const getUseModel = (req: Request, tokenCount: number) => {
  const [provider, model] = req.body.model.split(",");
  if (provider && model) {
    return {
      provider,
      model,
    };
const getUseModel = (req: any, tokenCount: number, config: any) => {
  if (req.body.model.includes(",")) {
    return req.body.model;
  }

  // if tokenCount is greater than 32K, use the long context model
  if (tokenCount > 1000 * 32) {
  // if tokenCount is greater than 60K, use the long context model
  if (tokenCount > 1000 * 60) {
    log("Using long context model due to token count:", tokenCount);
    const [provider, model] = req.config.Router!.longContext.split(",");
    return {
      provider,
      model,
    };
    return config.Router!.longContext;
  }
  // If the model is claude-3-5-haiku, use the background model
  if (req.body.model?.startsWith("claude-3-5-haiku")) {
    log("Using background model for ", req.body.model);
    const [provider, model] = req.config.Router!.background.split(",");
    return {
      provider,
      model,
    };
    return config.Router!.background;
  }
  // if thinking exists, use the think model
  if (req.body.thinking) {
    log("Using think model for ", req.body.thinking);
    const [provider, model] = req.config.Router!.think.split(",");
    return {
      provider,
      model,
    };
    return config.Router!.think;
  }
  return {
    provider: "default",
    model: req.config.OPENAI_MODEL,
  };
  return config.Router!.default;
};

export const router = async (
  req: Request,
  res: Response,
  next: NextFunction
) => {
export const router = async (req: any, res: any, config: any) => {
  const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
  try {
    let tokenCount = 0;

@@ -102,14 +77,11 @@ export const router = async (
      }
    });
  }
    const { provider, model } = getUseModel(req, tokenCount);
    req.provider = provider;
    const model = getUseModel(req, tokenCount, config);
    req.body.model = model;
  } catch (error) {
  } catch (error: any) {
    log("Error in router middleware:", error.message);
    req.provider = "default";
    req.body.model = req.config.OPENAI_MODEL;
  } finally {
    next();
    req.body.model = config.Router!.default;
  }
  return;
};
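
To summarize the new selection rules: getUseModel now returns a "provider,model" string instead of a { provider, model } pair, and the resolution order follows the branches above. Worked examples, assuming a config with all four Router entries (hypothetical values):

// Given Router: { default: "openai,gpt-4o", longContext: "openai,gpt-4o-long",
//                 background: "ollama,qwen2.5", think: "deepseek,deepseek-reasoner" }
// - body.model already contains a comma       -> returned unchanged
// - tokenCount > 60_000                       -> Router.longContext
// - body.model starts with "claude-3-5-haiku" -> Router.background
// - body.thinking is set                      -> Router.think
// - otherwise                                 -> Router.default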

@@ -15,7 +15,7 @@ export function showStatus() {
    console.log('');
    console.log('🚀 Ready to use! Run the following commands:');
    console.log('   ccr code    # Start coding with Claude');
    console.log('   ccr close   # Stop the service');
    console.log('   ccr stop    # Stop the service');
  } else {
    console.log('❌ Status: Not Running');
    console.log('');
@@ -1,343 +0,0 @@

import { Response } from "express";
import { OpenAI } from "openai";
import { log } from "./log";

interface ContentBlock {
  type: string;
  id?: string;
  name?: string;
  input?: any;
  text?: string;
}

interface MessageEvent {
  type: string;
  message?: {
    id: string;
    type: string;
    role: string;
    content: any[];
    model: string;
    stop_reason: string | null;
    stop_sequence: string | null;
    usage: {
      input_tokens: number;
      output_tokens: number;
    };
  };
  delta?: {
    stop_reason?: string;
    stop_sequence?: string | null;
    content?: ContentBlock[];
    type?: string;
    text?: string;
    partial_json?: string;
  };
  index?: number;
  content_block?: ContentBlock;
  usage?: {
    input_tokens: number;
    output_tokens: number;
  };
}

export async function streamOpenAIResponse(
  res: Response,
  completion: any,
  model: string,
  body: any
) {
  const write = (data: string) => {
    log("response: ", data);
    res.write(data);
  };
  const messageId = "msg_" + Date.now();
  if (!body.stream) {
    let content: any = [];
    if (completion.choices[0].message.content) {
      content = [{ text: completion.choices[0].message.content, type: "text" }];
    } else if (completion.choices[0].message.tool_calls) {
      content = completion.choices[0].message.tool_calls.map((item: any) => {
        return {
          type: 'tool_use',
          id: item.id,
          name: item.function?.name,
          input: item.function?.arguments ? JSON.parse(item.function.arguments) : {},
        };
      });
    }

    const result = {
      id: messageId,
      type: "message",
      role: "assistant",
      // @ts-ignore
      content: content,
      stop_reason: completion.choices[0].finish_reason === 'tool_calls' ? "tool_use" : "end_turn",
      stop_sequence: null,
    };
    try {
      res.json(result);
      res.end();
      return;
    } catch (error) {
      log("Error sending response:", error);
      res.status(500).send("Internal Server Error");
    }
  }

  let contentBlockIndex = 0;
  let currentContentBlocks: ContentBlock[] = [];

  // Send message_start event
  const messageStart: MessageEvent = {
    type: "message_start",
    message: {
      id: messageId,
      type: "message",
      role: "assistant",
      content: [],
      model,
      stop_reason: null,
      stop_sequence: null,
      usage: { input_tokens: 1, output_tokens: 1 },
    },
  };
  write(`event: message_start\ndata: ${JSON.stringify(messageStart)}\n\n`);

  let isToolUse = false;
  let toolUseJson = "";
  let hasStartedTextBlock = false;
  let currentToolCallId: string | null = null;
  let toolCallJsonMap = new Map<string, string>();

  try {
    for await (const chunk of completion) {
      log("Processing chunk:", chunk);
      const delta = chunk.choices[0].delta;

      if (delta.tool_calls && delta.tool_calls.length > 0) {
        for (const toolCall of delta.tool_calls) {
          const toolCallId = toolCall.id;

          // Check if this is a new tool call by ID
          if (toolCallId && toolCallId !== currentToolCallId) {
            // End previous tool call if one was active
            if (isToolUse && currentToolCallId) {
              const contentBlockStop: MessageEvent = {
                type: "content_block_stop",
                index: contentBlockIndex,
              };
              write(
                `event: content_block_stop\ndata: ${JSON.stringify(
                  contentBlockStop
                )}\n\n`
              );
            }

            // Start new tool call block
            isToolUse = true;
            currentToolCallId = toolCallId;
            contentBlockIndex++;
            toolCallJsonMap.set(toolCallId, ""); // Initialize JSON accumulator for this tool call

            const toolBlock: ContentBlock = {
              type: "tool_use",
              id: toolCallId,
              name: toolCall.function?.name,
              input: {},
            };

            const toolBlockStart: MessageEvent = {
              type: "content_block_start",
              index: contentBlockIndex,
              content_block: toolBlock,
            };

            currentContentBlocks.push(toolBlock);

            write(
              `event: content_block_start\ndata: ${JSON.stringify(
                toolBlockStart
              )}\n\n`
            );
          }

          // Stream tool call JSON
          if (toolCall.function?.arguments && currentToolCallId) {
            const jsonDelta: MessageEvent = {
              type: "content_block_delta",
              index: contentBlockIndex,
              delta: {
                type: "input_json_delta",
                partial_json: toolCall.function.arguments,
              },
            };

            // Accumulate JSON for this specific tool call
            const currentJson = toolCallJsonMap.get(currentToolCallId) || "";
            toolCallJsonMap.set(currentToolCallId, currentJson + toolCall.function.arguments);
            toolUseJson = toolCallJsonMap.get(currentToolCallId) || "";

            try {
              const parsedJson = JSON.parse(toolUseJson);
              currentContentBlocks[contentBlockIndex].input = parsedJson;
            } catch (e) {
              log("JSON parsing error (continuing to accumulate):", e);
              // JSON not yet complete, continue accumulating
            }

            write(
              `event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n`
            );
          }
        }
      } else if (delta.content) {
        // Handle regular text content
        if (isToolUse) {
          log("Tool call ended here:", delta);
          // End previous tool call block
          const contentBlockStop: MessageEvent = {
            type: "content_block_stop",
            index: contentBlockIndex,
          };

          write(
            `event: content_block_stop\ndata: ${JSON.stringify(
              contentBlockStop
            )}\n\n`
          );
          contentBlockIndex++;
          isToolUse = false;
          currentToolCallId = null;
          toolUseJson = ""; // Reset for safety
        }

        if (!delta.content) continue;

        // If text block not yet started, send content_block_start
        if (!hasStartedTextBlock) {
          const textBlock: ContentBlock = {
            type: "text",
            text: "",
          };

          const textBlockStart: MessageEvent = {
            type: "content_block_start",
            index: contentBlockIndex,
            content_block: textBlock,
          };

          currentContentBlocks.push(textBlock);

          write(
            `event: content_block_start\ndata: ${JSON.stringify(
              textBlockStart
            )}\n\n`
          );
          hasStartedTextBlock = true;
        }

        // Send regular text content
        const contentDelta: MessageEvent = {
          type: "content_block_delta",
          index: contentBlockIndex,
          delta: {
            type: "text_delta",
            text: delta.content,
          },
        };

        // Update content block text
        if (currentContentBlocks[contentBlockIndex]) {
          currentContentBlocks[contentBlockIndex].text += delta.content;
        }

        write(
          `event: content_block_delta\ndata: ${JSON.stringify(
            contentDelta
          )}\n\n`
        );
      }
    }
  } catch (e: any) {
    // If text block not yet started, send content_block_start
    if (!hasStartedTextBlock) {
      const textBlock: ContentBlock = {
        type: "text",
        text: "",
      };

      const textBlockStart: MessageEvent = {
        type: "content_block_start",
        index: contentBlockIndex,
        content_block: textBlock,
      };

      currentContentBlocks.push(textBlock);

      write(
        `event: content_block_start\ndata: ${JSON.stringify(
          textBlockStart
        )}\n\n`
      );
      hasStartedTextBlock = true;
    }

    // Send regular text content
    const contentDelta: MessageEvent = {
      type: "content_block_delta",
      index: contentBlockIndex,
      delta: {
        type: "text_delta",
        text: JSON.stringify(e),
      },
    };

    // Update content block text
    if (currentContentBlocks[contentBlockIndex]) {
      currentContentBlocks[contentBlockIndex].text += JSON.stringify(e);
    }

    write(
      `event: content_block_delta\ndata: ${JSON.stringify(contentDelta)}\n\n`
    );
  }

  // Close last content block if any is open
  if (isToolUse || hasStartedTextBlock) {
    const contentBlockStop: MessageEvent = {
      type: "content_block_stop",
      index: contentBlockIndex,
    };

    write(
      `event: content_block_stop\ndata: ${JSON.stringify(contentBlockStop)}\n\n`
    );
  }

  // Send message_delta event with appropriate stop_reason
  const messageDelta: MessageEvent = {
    type: "message_delta",
    delta: {
      stop_reason: isToolUse ? "tool_use" : "end_turn",
      stop_sequence: null,
      content: currentContentBlocks,
    },
    usage: { input_tokens: 100, output_tokens: 150 },
  };
  if (!isToolUse) {
    log("body: ", body, "messageDelta: ", messageDelta);
  }

  write(`event: message_delta\ndata: ${JSON.stringify(messageDelta)}\n\n`);

  // Send message_stop event
  const messageStop: MessageEvent = {
    type: "message_stop",
  };

  write(`event: message_stop\ndata: ${JSON.stringify(messageStop)}\n\n`);
  res.end();
}
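
For orientation, the deleted streamer emitted Anthropic-style SSE in this order for a plain text reply; tool calls insert extra content_block_start/stop pairs, one per call:

// event: message_start
// event: content_block_start   { type: "text" }
// event: content_block_delta   { type: "text_delta" }   (repeated per chunk)
// event: content_block_stop
// event: message_delta         { stop_reason: "end_turn" }
// event: message_stop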