diff --git a/src/utils/stream.ts b/src/utils/stream.ts index c139cf3..d502947 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -53,39 +53,38 @@ export async function streamOpenAIResponse( }; const messageId = "msg_" + Date.now(); if (!body.stream) { - res.json({ + let content: any = []; + if (completion.choices[0].message.content) { + content = [ { text: completion.choices[0].message.content, type: "text" } ]; + } + else if (completion.choices[0].message.tool_calls) { + content = completion.choices[0].message.tool_calls.map((item: any) => { + return { + type: 'tool_use', + id: item.id, + name: item.function?.name, + input: item.function?.arguments ? JSON.parse(item.function.arguments) : {}, + }; + }); + } + + const result = { id: messageId, type: "message", role: "assistant", // @ts-ignore - content: completion.choices[0].message.content || - completion.choices[0].message.tool_calls?.map((item) => { - return { - type: "tool_use", - id: item.id, - name: item.function?.name, - input: item.function?.arguments - ? JSON.parse(item.function.arguments) - : {}, - }; - }) || [ - { - type: "text", - text: "", - }, - ], - stop_reason: - completion.choices[0].finish_reason === "tool_calls" - ? "tool_use" - : "end_turn", + content: content, + stop_reason: completion.choices[0].finish_reason === 'tool_calls' ? "tool_use" : "end_turn", stop_sequence: null, - usage: { - input_tokens: 100, - output_tokens: 50, - }, - }); - res.end(); - return; + }; + try { + res.json(result); + res.end(); + return; + } catch (error) { + log("Error sending response:", error); + res.status(500).send("Internal Server Error"); + } } let contentBlockIndex = 0; @@ -110,6 +109,8 @@ export async function streamOpenAIResponse( let isToolUse = false; let toolUseJson = ""; let hasStartedTextBlock = false; + let currentToolCallId: string | null = null; + let toolCallJsonMap = new Map(); try { for await (const chunk of completion) { @@ -117,58 +118,80 @@ export async function streamOpenAIResponse( const delta = chunk.choices[0].delta; if (delta.tool_calls && delta.tool_calls.length > 0) { - const toolCall = delta.tool_calls[0]; + for (const toolCall of delta.tool_calls) { + const toolCallId = toolCall.id; + + // Check if this is a new tool call by ID + if (toolCallId && toolCallId !== currentToolCallId) { + // End previous tool call if one was active + if (isToolUse && currentToolCallId) { + const contentBlockStop: MessageEvent = { + type: "content_block_stop", + index: contentBlockIndex, + }; + write( + `event: content_block_stop\ndata: ${JSON.stringify( + contentBlockStop + )}\n\n` + ); + } - if (!isToolUse) { - // Start new tool call block - isToolUse = true; - const toolBlock: ContentBlock = { - type: "tool_use", - id: `toolu_${Date.now()}`, - name: toolCall.function?.name, - input: {}, - }; + // Start new tool call block + isToolUse = true; + currentToolCallId = toolCallId; + contentBlockIndex++; + toolCallJsonMap.set(toolCallId, ""); // Initialize JSON accumulator for this tool call - const toolBlockStart: MessageEvent = { - type: "content_block_start", - index: contentBlockIndex, - content_block: toolBlock, - }; + const toolBlock: ContentBlock = { + type: "tool_use", + id: toolCallId, + name: toolCall.function?.name, + input: {}, + }; - currentContentBlocks.push(toolBlock); + const toolBlockStart: MessageEvent = { + type: "content_block_start", + index: contentBlockIndex, + content_block: toolBlock, + }; - write( - `event: content_block_start\ndata: ${JSON.stringify( - toolBlockStart - )}\n\n` - ); - toolUseJson = ""; - } + currentContentBlocks.push(toolBlock); - // Stream tool call JSON - if (toolCall.function?.arguments) { - const jsonDelta: MessageEvent = { - type: "content_block_delta", - index: contentBlockIndex, - delta: { - type: "input_json_delta", - partial_json: toolCall.function?.arguments, - }, - }; - - toolUseJson += toolCall.function.arguments; - - try { - const parsedJson = JSON.parse(toolUseJson); - currentContentBlocks[contentBlockIndex].input = parsedJson; - } catch (e) { - log(e); - // JSON not yet complete, continue accumulating + write( + `event: content_block_start\ndata: ${JSON.stringify( + toolBlockStart + )}\n\n` + ); } - write( - `event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n` - ); + // Stream tool call JSON + if (toolCall.function?.arguments && currentToolCallId) { + const jsonDelta: MessageEvent = { + type: "content_block_delta", + index: contentBlockIndex, + delta: { + type: "input_json_delta", + partial_json: toolCall.function.arguments, + }, + }; + + // Accumulate JSON for this specific tool call + const currentJson = toolCallJsonMap.get(currentToolCallId) || ""; + toolCallJsonMap.set(currentToolCallId, currentJson + toolCall.function.arguments); + toolUseJson = toolCallJsonMap.get(currentToolCallId) || ""; + + try { + const parsedJson = JSON.parse(toolUseJson); + currentContentBlocks[contentBlockIndex].input = parsedJson; + } catch (e) { + log("JSON parsing error (continuing to accumulate):", e); + // JSON not yet complete, continue accumulating + } + + write( + `event: content_block_delta\ndata: ${JSON.stringify(jsonDelta)}\n\n` + ); + } } } else if (delta.content) { // Handle regular text content @@ -187,6 +210,8 @@ export async function streamOpenAIResponse( ); contentBlockIndex++; isToolUse = false; + currentToolCallId = null; + toolUseJson = ""; // Reset for safety } if (!delta.content) continue; @@ -280,15 +305,17 @@ export async function streamOpenAIResponse( ); } - // Close last content block - const contentBlockStop: MessageEvent = { - type: "content_block_stop", - index: contentBlockIndex, - }; + // Close last content block if any is open + if (isToolUse || hasStartedTextBlock) { + const contentBlockStop: MessageEvent = { + type: "content_block_stop", + index: contentBlockIndex, + }; - write( - `event: content_block_stop\ndata: ${JSON.stringify(contentBlockStop)}\n\n` - ); + write( + `event: content_block_stop\ndata: ${JSON.stringify(contentBlockStop)}\n\n` + ); + } // Send message_delta event with appropriate stop_reason const messageDelta: MessageEvent = {