Refactor plugin
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
import { Response } from "express";
|
||||
import { OpenAI } from "openai";
|
||||
import { Request, Response } from "express";
|
||||
import { log } from "./log";
|
||||
import { PLUGINS } from "../middlewares/plugin";
|
||||
import { sha256 } from ".";
|
||||
|
||||
declare module "express" {
|
||||
interface Request {
|
||||
provider?: string;
|
||||
}
|
||||
}
|
||||
|
||||
interface ContentBlock {
|
||||
type: string;
|
||||
@@ -42,28 +49,61 @@ interface MessageEvent {
|
||||
}
|
||||
|
||||
export async function streamOpenAIResponse(
|
||||
req: Request,
|
||||
res: Response,
|
||||
completion: any,
|
||||
model: string,
|
||||
body: any
|
||||
_completion: any
|
||||
) {
|
||||
const write = (data: string) => {
|
||||
log("response: ", data);
|
||||
res.write(data);
|
||||
let completion = _completion;
|
||||
res.locals.completion = completion;
|
||||
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.beforeTransformResponse) {
|
||||
const result = await plugin.beforeTransformResponse(req, res, {
|
||||
completion,
|
||||
});
|
||||
if (result) {
|
||||
completion = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
const write = async (data: string) => {
|
||||
let eventData = data;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: eventData,
|
||||
});
|
||||
if (typeof hookResult === "string") {
|
||||
eventData = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (eventData) {
|
||||
log("response: ", eventData);
|
||||
res.write(eventData);
|
||||
}
|
||||
};
|
||||
const messageId = "msg_" + Date.now();
|
||||
if (!body.stream) {
|
||||
if (!req.body.stream) {
|
||||
let content: any = [];
|
||||
if (completion.choices[0].message.content) {
|
||||
content = [ { text: completion.choices[0].message.content, type: "text" } ];
|
||||
}
|
||||
else if (completion.choices[0].message.tool_calls) {
|
||||
content = [{ text: completion.choices[0].message.content, type: "text" }];
|
||||
} else if (completion.choices[0].message.tool_calls) {
|
||||
content = completion.choices[0].message.tool_calls.map((item: any) => {
|
||||
return {
|
||||
type: 'tool_use',
|
||||
type: "tool_use",
|
||||
id: item.id,
|
||||
name: item.function?.name,
|
||||
input: item.function?.arguments ? JSON.parse(item.function.arguments) : {},
|
||||
input: item.function?.arguments
|
||||
? JSON.parse(item.function.arguments)
|
||||
: {},
|
||||
};
|
||||
});
|
||||
}
|
||||
@@ -74,11 +114,29 @@ export async function streamOpenAIResponse(
|
||||
role: "assistant",
|
||||
// @ts-ignore
|
||||
content: content,
|
||||
stop_reason: completion.choices[0].finish_reason === 'tool_calls' ? "tool_use" : "end_turn",
|
||||
stop_reason:
|
||||
completion.choices[0].finish_reason === "tool_calls"
|
||||
? "tool_use"
|
||||
: "end_turn",
|
||||
stop_sequence: null,
|
||||
};
|
||||
try {
|
||||
res.json(result);
|
||||
res.locals.transformedCompletion = result;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: res.locals.transformedCompletion,
|
||||
});
|
||||
if (hookResult) {
|
||||
res.locals.transformedCompletion = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
res.json(res.locals.transformedCompletion);
|
||||
res.end();
|
||||
return;
|
||||
} catch (error) {
|
||||
@@ -98,7 +156,7 @@ export async function streamOpenAIResponse(
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [],
|
||||
model,
|
||||
model: req.body.model,
|
||||
stop_reason: null,
|
||||
stop_sequence: null,
|
||||
usage: { input_tokens: 1, output_tokens: 1 },
|
||||
@@ -118,11 +176,13 @@ export async function streamOpenAIResponse(
|
||||
const delta = chunk.choices[0].delta;
|
||||
|
||||
if (delta.tool_calls && delta.tool_calls.length > 0) {
|
||||
for (const toolCall of delta.tool_calls) {
|
||||
const toolCallId = toolCall.id;
|
||||
|
||||
// Check if this is a new tool call by ID
|
||||
if (toolCallId && toolCallId !== currentToolCallId) {
|
||||
// Handle each tool call in the current chunk
|
||||
for (const [index, toolCall] of delta.tool_calls.entries()) {
|
||||
// Generate a stable ID for this tool call position
|
||||
const toolCallId = toolCall.id || `tool_${index}`;
|
||||
|
||||
// If this position doesn't have an active tool call, start a new one
|
||||
if (!toolCallJsonMap.has(`${index}`)) {
|
||||
// End previous tool call if one was active
|
||||
if (isToolUse && currentToolCallId) {
|
||||
const contentBlockStop: MessageEvent = {
|
||||
@@ -138,13 +198,13 @@ export async function streamOpenAIResponse(
|
||||
|
||||
// Start new tool call block
|
||||
isToolUse = true;
|
||||
currentToolCallId = toolCallId;
|
||||
currentToolCallId = `${index}`;
|
||||
contentBlockIndex++;
|
||||
toolCallJsonMap.set(toolCallId, ""); // Initialize JSON accumulator for this tool call
|
||||
toolCallJsonMap.set(`${index}`, ""); // Initialize JSON accumulator for this tool call
|
||||
|
||||
const toolBlock: ContentBlock = {
|
||||
type: "tool_use",
|
||||
id: toolCallId,
|
||||
id: toolCallId, // Use the original ID if available
|
||||
name: toolCall.function?.name,
|
||||
input: {},
|
||||
};
|
||||
@@ -164,8 +224,8 @@ export async function streamOpenAIResponse(
|
||||
);
|
||||
}
|
||||
|
||||
// Stream tool call JSON
|
||||
if (toolCall.function?.arguments && currentToolCallId) {
|
||||
// Stream tool call JSON for this position
|
||||
if (toolCall.function?.arguments) {
|
||||
const jsonDelta: MessageEvent = {
|
||||
type: "content_block_delta",
|
||||
index: contentBlockIndex,
|
||||
@@ -175,27 +235,39 @@ export async function streamOpenAIResponse(
|
||||
},
|
||||
};
|
||||
|
||||
// Accumulate JSON for this specific tool call
|
||||
const currentJson = toolCallJsonMap.get(currentToolCallId) || "";
|
||||
toolCallJsonMap.set(currentToolCallId, currentJson + toolCall.function.arguments);
|
||||
toolUseJson = toolCallJsonMap.get(currentToolCallId) || "";
|
||||
// Accumulate JSON for this specific tool call position
|
||||
const currentJson = toolCallJsonMap.get(`${index}`) || "";
|
||||
const newJson = currentJson + toolCall.function.arguments;
|
||||
toolCallJsonMap.set(`${index}`, newJson);
|
||||
|
||||
try {
|
||||
const parsedJson = JSON.parse(toolUseJson);
|
||||
currentContentBlocks[contentBlockIndex].input = parsedJson;
|
||||
} catch (e) {
|
||||
log("JSON parsing error (continuing to accumulate):", e);
|
||||
// JSON not yet complete, continue accumulating
|
||||
// Try to parse accumulated JSON
|
||||
if (isValidJson(newJson)) {
|
||||
try {
|
||||
const parsedJson = JSON.parse(newJson);
|
||||
const blockIndex = currentContentBlocks.findIndex(
|
||||
(block) => block.type === "tool_use" && block.id === toolCallId
|
||||
);
|
||||
if (blockIndex !== -1) {
|
||||
currentContentBlocks[blockIndex].input = parsedJson;
|
||||
}
|
||||
} catch (e) {
|
||||
log("JSON parsing error (continuing to accumulate):", e);
|
||||
}
|
||||
}
|
||||
|
||||
write(
|
||||
`event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n`
|
||||
`event: content_block_delta\ndata: ${JSON.stringify(
|
||||
jsonDelta
|
||||
)}\n\n`
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if (delta.content) {
|
||||
// Handle regular text content
|
||||
if (isToolUse) {
|
||||
} else if (delta.content || chunk.choices[0].finish_reason) {
|
||||
// Handle regular text content or completion
|
||||
if (
|
||||
isToolUse &&
|
||||
(delta.content || chunk.choices[0].finish_reason === "tool_calls")
|
||||
) {
|
||||
log("Tool call ended here:", delta);
|
||||
// End previous tool call block
|
||||
const contentBlockStop: MessageEvent = {
|
||||
@@ -214,8 +286,6 @@ export async function streamOpenAIResponse(
|
||||
toolUseJson = ""; // Reset for safety
|
||||
}
|
||||
|
||||
if (!delta.content) continue;
|
||||
|
||||
// If text block not yet started, send content_block_start
|
||||
if (!hasStartedTextBlock) {
|
||||
const textBlock: ContentBlock = {
|
||||
@@ -317,18 +387,34 @@ export async function streamOpenAIResponse(
|
||||
);
|
||||
}
|
||||
|
||||
res.locals.transformedCompletion = currentContentBlocks;
|
||||
for (const [name, plugin] of PLUGINS.entries()) {
|
||||
if (name.includes(",") && !name.startsWith(`${req.provider},`)) {
|
||||
continue;
|
||||
}
|
||||
if (plugin.afterTransformResponse) {
|
||||
const hookResult = await plugin.afterTransformResponse(req, res, {
|
||||
completion: res.locals.completion,
|
||||
transformedCompletion: res.locals.transformedCompletion,
|
||||
});
|
||||
if (hookResult) {
|
||||
res.locals.transformedCompletion = hookResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send message_delta event with appropriate stop_reason
|
||||
const messageDelta: MessageEvent = {
|
||||
type: "message_delta",
|
||||
delta: {
|
||||
stop_reason: isToolUse ? "tool_use" : "end_turn",
|
||||
stop_sequence: null,
|
||||
content: currentContentBlocks,
|
||||
content: res.locals.transformedCompletion,
|
||||
},
|
||||
usage: { input_tokens: 100, output_tokens: 150 },
|
||||
};
|
||||
if (!isToolUse) {
|
||||
log("body: ", body, "messageDelta: ", messageDelta);
|
||||
log("body: ", req.body, "messageDelta: ", messageDelta);
|
||||
}
|
||||
|
||||
write(`event: message_delta\ndata: ${JSON.stringify(messageDelta)}\n\n`);
|
||||
@@ -341,3 +427,42 @@ export async function streamOpenAIResponse(
|
||||
write(`event: message_stop\ndata: ${JSON.stringify(messageStop)}\n\n`);
|
||||
res.end();
|
||||
}
|
||||
|
||||
// Add helper function at the top of the file
|
||||
function isValidJson(str: string): boolean {
|
||||
// Check if the string contains both opening and closing braces/brackets
|
||||
const hasOpenBrace = str.includes("{");
|
||||
const hasCloseBrace = str.includes("}");
|
||||
const hasOpenBracket = str.includes("[");
|
||||
const hasCloseBracket = str.includes("]");
|
||||
|
||||
// Check if we have matching pairs
|
||||
if ((hasOpenBrace && !hasCloseBrace) || (!hasOpenBrace && hasCloseBrace)) {
|
||||
return false;
|
||||
}
|
||||
if (
|
||||
(hasOpenBracket && !hasCloseBracket) ||
|
||||
(!hasOpenBracket && hasCloseBracket)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Count nested braces/brackets
|
||||
let braceCount = 0;
|
||||
let bracketCount = 0;
|
||||
|
||||
for (const char of str) {
|
||||
if (char === "{") braceCount++;
|
||||
if (char === "}") braceCount--;
|
||||
if (char === "[") bracketCount++;
|
||||
if (char === "]") bracketCount--;
|
||||
|
||||
// If we ever go negative, the JSON is invalid
|
||||
if (braceCount < 0 || bracketCount < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// All braces/brackets should be matched
|
||||
return braceCount === 0 && bracketCount === 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user