From 93f9b84af6fb4708dfbd2ada36de22a3318c4742 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 28 Feb 2025 10:53:59 +0100 Subject: [PATCH] Add API test for `get_alerts_dataset_info` tool --- .../get_relevant_field_names.ts | 6 +- .../functions/get_alerts_dataset_info.spec.ts | 506 ++++++++++++++++++ .../complete/functions/helpers.ts | 7 + .../apis/observability/ai_assistant/index.ts | 1 + .../common/create_llm_proxy.ts | 167 ++++-- 5 files changed, 629 insertions(+), 58 deletions(-) create mode 100644 x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts index 847d9f2980053..bc43f64189538 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts @@ -13,6 +13,8 @@ import { MessageRole, ShortIdTable, type Message } from '../../../common'; import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; import { FunctionCallChatFunction } from '../../service/types'; +const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields'; + export async function getRelevantFieldNames({ index, start, @@ -122,7 +124,7 @@ export async function getRelevantFieldNames({ ], functions: [ { - name: 'select_relevant_fields', + name: SELECT_RELEVANT_FIELDS_NAME, description: 'The IDs of the fields you consider relevant to the conversation', parameters: { type: 'object', @@ -138,7 +140,7 @@ export async function getRelevantFieldNames({ } as const, }, ], - functionCall: 'select_relevant_fields', + functionCall: SELECT_RELEVANT_FIELDS_NAME, }) ).pipe(concatenateChatCompletionChunks()); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts new file mode 100644 index 0000000000000..971293d0cbe3b --- /dev/null +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts @@ -0,0 +1,506 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common'; +import expect from '@kbn/expect'; +import { ApmRuleType } from '@kbn/rule-data-utils'; +import { apm, timerange } from '@kbn/apm-synthtrace-client'; +import { ApmSynthtraceEsClient } from '@kbn/apm-synthtrace'; +import { RoleCredentials } from '@kbn/ftr-common-functional-services'; +import { ChatCompletionAssistantMessageParam } from 'openai/resources'; +import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper'; +import { + LlmProxy, + createLlmProxy, +} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; +import { getMessageAddedEvents } from './helpers'; +import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; +import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper'; + +export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) { + const log = getService('log'); + const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi'); + const alertingApi = getService('alertingApi'); + const samlAuth = getService('samlAuth'); + + describe('function: get_alerts_dataset_info', function () { + // Fails on MKI: https://github.com/elastic/kibana/issues/205581 + this.tags(['failsOnMKI']); + let llmProxy: LlmProxy; + let connectorId: string; + let messageAddedEvents: MessageAddEvent[]; + let apmSynthtraceEsClient: ApmSynthtraceEsClient; + let roleAuthc: RoleCredentials; + let createdRuleId: string; + let expectedRelevantFieldNames: string[]; + + before(async () => { + ({ apmSynthtraceEsClient } = await createSyntheticApmData(getService)); + ({ roleAuthc, createdRuleId } = await createApmErrorCountRule(getService)); + + llmProxy = await createLlmProxy(log); + connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({ + port: llmProxy.getPort(), + }); + + llmProxy.interceptConversation( + { + content: '', + tool_calls: [ + { + function: { + name: 'get_alerts_dataset_info', + arguments: JSON.stringify({ start: 'now-10d', end: 'now' }), + }, + index: 0, + // @ts-expect-error + id: 'call_hFHMH5idQKW5qtoGOsmEChGE', + }, + ], + }, + { + name: 'Function request: "get_alerts_dataset_info"', + } + ); + + llmProxy.interceptToolChoice({ + toolName: 'select_relevant_fields', + response: (requestBody) => { + const messageWithFieldIds = requestBody.messages.find((message) => { + const content = message?.content as string; + return content.includes('This is the list:') && content.includes('@timestamp'); + }); + + const topFields = (messageWithFieldIds?.content as string) + .replace('This is the list:', '') + .trim() + .split('\n') + .map((line) => JSON.parse(line)) + .slice(0, 5); + + expectedRelevantFieldNames = topFields.map(({ field }) => field); + + const fieldIds = topFields.map(({ id }) => id); + return { + content: '', + tool_calls: [ + { + function: { + name: 'select_relevant_fields', + arguments: JSON.stringify({ fieldIds }), + }, + index: 0, + toolCallId: 'call_khfIcRe5V0M3ZaX2GFvFQVQt', + }, + ], + }; + }, + }); + + llmProxy.interceptConversation( + { + content: '', + tool_calls: [ + { + function: { + name: 'alerts', + arguments: JSON.stringify({ start: 'now-10d', end: 'now' }), + }, + index: 0, + // @ts-expect-error + id: 'call_asPqcc7PZvH3h645wP34CX5J', + }, + ], + }, + { + name: 'Function request: alerts', + } + ); + + llmProxy.interceptConversation(`You have active alerts for the past 10 days. Back to work!`); + + const { status, body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'POST /internal/observability_ai_assistant/chat/complete', + params: { + body: { + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'How many alerts do I have for the past 10 days?', + }, + }, + ], + connectorId, + persist: false, + screenContexts: [], + scopes: ['observability' as const], + }, + }, + }); + + expect(status).to.be(200); + + await llmProxy.waitForAllInterceptorsSettled(); + messageAddedEvents = getMessageAddedEvents(body); + }); + + after(async () => { + llmProxy.close(); + await observabilityAIAssistantAPIClient.deleteActionConnector({ + actionId: connectorId, + }); + + await apmSynthtraceEsClient.clean(); + await alertingApi.cleanUpAlerts({ + roleAuthc, + ruleId: createdRuleId, + alertIndexName: APM_ALERTS_INDEX, + consumer: 'apm', + }); + + await samlAuth.invalidateM2mApiKeyWithRoleScope(roleAuthc); + }); + + it('sends correct ES fields', async () => { + const messageWithFields = messageAddedEvents.find( + ({ message }) => + message.message.role === MessageRole.User && + message.message.name === 'get_alerts_dataset_info' + ); + + const parsedContent = JSON.parse(messageWithFields?.message.message.content!) as { + fields: string[]; + }; + const fieldNames = parsedContent.fields.map((field) => field.split(':')[0]); + + expect(fieldNames).to.eql(expectedRelevantFieldNames); + expect(parsedContent.fields).to.eql([ + '@timestamp:date', + '_id:_id', + '_ignored:string', + '_index:_index', + '_score:number', + ]); + }); + + it('sends alerts for the given period', async () => { + const messageWithAlerts = messageAddedEvents.find( + ({ message }) => + message.message.role === MessageRole.User && message.message.name === 'alerts' + ); + + const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as { + total: number; + alerts: any[]; + }; + expect(parsedContent.total).to.above(0); + // expect(parsedContent.alerts).to.eql(); + }); + + it('sends messageAdded events in the correct order', async () => { + expect( + messageAddedEvents.map(({ message }) => { + const { role, name, function_call: functionCall } = message.message; + if (functionCall) { + return { function_call: functionCall, role }; + } + + return { name, role }; + }) + ).to.eql([ + { + role: 'assistant', + function_call: { name: 'context', trigger: 'assistant' }, + }, + { name: 'context', role: 'user' }, + { + role: 'assistant', + function_call: { + name: 'get_alerts_dataset_info', + arguments: '{"start":"now-10d","end":"now"}', + trigger: 'assistant', + }, + }, + { name: 'get_alerts_dataset_info', role: 'user' }, + { + role: 'assistant', + function_call: { + name: 'alerts', + arguments: '{"start":"now-10d","end":"now"}', + trigger: 'assistant', + }, + }, + { name: 'alerts', role: 'user' }, + { + role: 'assistant', + function_call: { name: '', arguments: '', trigger: 'assistant' }, + }, + ]); + }); + + it('makes the right requests to the LLM proxy', async () => { + expect(llmProxy.interceptedRequests.length).to.eql(4); + + const actualRequests = llmProxy.interceptedRequests.map(({ requestBody }) => { + const messages = requestBody.messages.map((message) => { + const { + role, + content, + tool_calls: toolCalls, + } = message as ChatCompletionAssistantMessageParam; + + let formattedContent = ''; + if (content) { + // @ts-expect-error + formattedContent = content?.includes('This is the list:') + ? 'This is the list:' + : content; + + formattedContent = + formattedContent.length > 100 + ? `${formattedContent.slice(0, 100)}...` + : formattedContent; + } + + return { + role, + content: formattedContent, + ...(toolCalls ? { tool_calls: toolCalls[0].function } : {}), + }; + }); + + return { + messages, + toolChoice: requestBody.tool_choice, + }; + }); + + expect(actualRequests).to.eql([ + { + messages: [ + { + role: 'system', + content: + 'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observabilit...', + }, + { + role: 'user', + content: 'How many alerts do I have for the past 10 days?', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'context', + arguments: '{}', + }, + }, + { + role: 'tool', + content: '{"screen_description":"","learnings":[]}', + }, + ], + toolChoice: 'auto', + }, + { + messages: [ + { + role: 'system', + content: + 'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observabilit...', + }, + { + role: 'user', + content: 'How many alerts do I have for the past 10 days?', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'context', + arguments: '{}', + }, + }, + { + role: 'tool', + content: '{"screen_description":"","learnings":[]}', + }, + { + role: 'user', + content: 'This is the list:', + }, + ], + toolChoice: { + function: { + name: 'select_relevant_fields', + }, + type: 'function', + }, + }, + { + messages: [ + { + role: 'system', + content: + 'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observabilit...', + }, + { + role: 'user', + content: 'How many alerts do I have for the past 10 days?', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'context', + arguments: '{}', + }, + }, + { + role: 'tool', + content: '{"screen_description":"","learnings":[]}', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'get_alerts_dataset_info', + arguments: '{"start":"now-10d","end":"now"}', + }, + }, + { + role: 'tool', + content: + '{"fields":["@timestamp:date","_id:_id","_ignored:string","_index:_index","_score:number"]}', + }, + ], + toolChoice: 'auto', + }, + { + messages: [ + { + role: 'system', + content: + 'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observabilit...', + }, + { + role: 'user', + content: 'How many alerts do I have for the past 10 days?', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'context', + arguments: '{}', + }, + }, + { + role: 'tool', + content: '{"screen_description":"","learnings":[]}', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'get_alerts_dataset_info', + arguments: '{"start":"now-10d","end":"now"}', + }, + }, + { + role: 'tool', + content: + '{"fields":["@timestamp:date","_id:_id","_ignored:string","_index:_index","_score:number"]}', + }, + { + role: 'assistant', + content: '', + tool_calls: { + name: 'alerts', + arguments: '{"start":"now-10d","end":"now"}', + }, + }, + { + role: 'tool', + content: + '{"total":1,"alerts":[{"processor.event":"error","kibana.alert.evaluation.value":15,"kibana.alert.eva...', + }, + ], + toolChoice: 'auto', + }, + ]); + }); + }); +} + +async function createApmErrorCountRule( + getService: DeploymentAgnosticFtrProviderContext['getService'] +) { + const alertingApi = getService('alertingApi'); + const samlAuth = getService('samlAuth'); + + const roleAuthc = await samlAuth.createM2mApiKeyWithRoleScope('editor'); + const createdRule = await alertingApi.createRule({ + ruleTypeId: ApmRuleType.ErrorCount, + name: 'APM error threshold', + consumer: 'apm', + schedule: { interval: '1m' }, + tags: ['apm'], + params: { + environment: 'production', + threshold: 1, + windowSize: 1, + windowUnit: 'h', + }, + roleAuthc, + }); + + const createdRuleId = createdRule.id as string; + const esResponse = await alertingApi.waitForDocumentInIndex({ + indexName: APM_ALERTS_INDEX, + ruleId: createdRuleId, + docCountTarget: 1, + }); + + return { + roleAuthc, + createdRuleId, + alerts: esResponse.hits.hits.map((hit) => hit._source!), + }; +} + +async function createSyntheticApmData( + getService: DeploymentAgnosticFtrProviderContext['getService'] +) { + const synthtrace = getService('synthtrace'); + const apmSynthtraceEsClient = await synthtrace.createApmSynthtraceEsClient(); + + const opbeansNode = apm + .service({ name: 'opbeans-node', environment: 'production', agentName: 'node' }) + .instance('instance'); + + const events = timerange('now-15m', 'now') + .ratePerMinute(1) + .generator((timestamp) => { + return [ + opbeansNode + .transaction({ transactionName: 'DELETE /user/:id' }) + .timestamp(timestamp) + .duration(100) + .failure() + .errors( + opbeansNode.error({ message: 'Unable to delete user' }).timestamp(timestamp + 50) + ), + ]; + }); + + await apmSynthtraceEsClient.index(events); + + return { apmSynthtraceEsClient }; +} diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts index b64295d3a255b..f36b9b9eb6037 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts @@ -45,6 +45,13 @@ export async function invokeChatCompleteWithFunctionRequest({ params: { body: { messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello from user', + }, + }, { '@timestamp': new Date().toISOString(), message: { diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts index 1d3d41ddb4400..31d0b5f5c836c 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts @@ -16,6 +16,7 @@ export default function aiAssistantApiIntegrationTests({ loadTestFile(require.resolve('./chat/chat.spec.ts')); loadTestFile(require.resolve('./complete/complete.spec.ts')); loadTestFile(require.resolve('./complete/functions/alerts.spec.ts')); + loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts')); loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts')); loadTestFile(require.resolve('./complete/functions/summarize.spec.ts')); loadTestFile(require.resolve('./public_complete/public_complete.spec.ts')); diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts index 5fcf41a33ebdb..ef2b87075e0b5 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts @@ -8,7 +8,7 @@ import { ToolingLog } from '@kbn/tooling-log'; import getPort from 'get-port'; import http, { type Server } from 'http'; -import { isString, once, pull } from 'lodash'; +import { isString, once, pull, isFunction } from 'lodash'; import OpenAI from 'openai'; import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title'; import pRetry from 'p-retry'; @@ -18,6 +18,8 @@ import { createOpenAiChunk } from './create_openai_chunk'; type Request = http.IncomingMessage; type Response = http.ServerResponse & { req: http.IncomingMessage }; +type LLMMessage = string[] | ToolMessage | string | undefined; + type RequestHandler = ( request: Request, response: Response, @@ -46,35 +48,47 @@ export interface LlmResponseSimulator { export class LlmProxy { server: Server; interval: NodeJS.Timeout; - interceptors: Array = []; + interceptedRequests: Array<{ + requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming; + matchingInterceptorName: string | undefined; + }> = []; constructor(private readonly port: number, private readonly log: ToolingLog) { - this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 1000); + this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 5000); this.server = http .createServer() .on('request', async (request, response) => { - this.log.info(`LLM request received`); - - const interceptors = this.interceptors.concat(); const requestBody = await getRequestBody(request); - while (interceptors.length) { - const interceptor = interceptors.shift()!; + const matchingInterceptor = this.interceptors.find(({ when }) => when(requestBody)); + this.interceptedRequests.push({ + requestBody, + matchingInterceptorName: matchingInterceptor?.name, + }); + if (matchingInterceptor) { + this.log.info(`Handling interceptor "${matchingInterceptor.name}"`); + matchingInterceptor.handle(request, response, requestBody); - if (interceptor.when(requestBody)) { - pull(this.interceptors, interceptor); - interceptor.handle(request, response, requestBody); - return; - } + this.log.debug(`Removing interceptor "${matchingInterceptor.name}"`); + pull(this.interceptors, matchingInterceptor); + return; } const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`; + const availableInterceptorNames = this.interceptors.map(({ name }) => name); + this.log.error( + `Available interceptors: ${JSON.stringify(availableInterceptorNames, null, 2)}` + ); + this.log.error( `${errorMessage}. Messages: ${JSON.stringify(requestBody.messages, null, 2)}` ); - response.writeHead(500, { errorMessage, messages: JSON.stringify(requestBody.messages) }); + response.writeHead(500, { + 'Elastic-Interceptor': 'Interceptor not found', + }); + response.write(sseEvent({ errorMessage, availableInterceptorNames })); response.end(); }) .on('error', (error) => { @@ -88,7 +102,8 @@ export class LlmProxy { } clear() { - this.interceptors.length = 0; + this.interceptors = []; + this.interceptedRequests = []; } close() { @@ -104,9 +119,11 @@ export class LlmProxy { return; } - const unsettledInterceptors = this.interceptors.map((i) => i.name).join(', '); + const unsettledInterceptors = this.interceptors.map((i) => i.name); this.log.debug( - `Waiting for the following interceptors to be called: ${unsettledInterceptors}` + `Waiting for the following interceptors to be called: ${JSON.stringify( + unsettledInterceptors + )}` ); if (this.interceptors.length > 0) { throw new Error(`Interceptors were not called: ${unsettledInterceptors}`); @@ -120,24 +137,47 @@ export class LlmProxy { } interceptConversation( - msg: Array | ToolMessage | string | undefined, + msg: LLMMessage, { - name = 'default_interceptor_conversation_name', + name, }: { name?: string; } = {} ) { return this.intercept( - name, - (body) => !isFunctionTitleRequest(body), + `Conversation interceptor: "${name ?? 'Unnamed'}"`, + (body) => { + // @ts-expect-error + return body.tool_choice?.function?.name === undefined; + }, msg ).completeAfterIntercept(); } + interceptToolChoice({ + toolName, + response, + }: { + toolName: string; + response: + | LLMMessage + | ((body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming) => LLMMessage); + }) { + return this.intercept( + `Tool choice interceptor: "${toolName}"`, + (body) => { + // @ts-expect-error + return body.tool_choice?.function?.name === toolName; + }, + response + ).completeAfterIntercept(); + } + interceptTitle(title: string) { return this.intercept( - `conversation_title_interceptor_${title.split(' ').join('_')}`, - (body) => isFunctionTitleRequest(body), + `Title interceptor: "${title}"`, + // @ts-expect-error + (body) => body.tool_choice?.function.name === TITLE_CONVERSATION_FUNCTION_NAME, { content: '', tool_calls: [ @@ -154,27 +194,22 @@ export class LlmProxy { ).completeAfterIntercept(); } - intercept< - TResponseChunks extends - | Array - | ToolMessage - | string - | undefined = undefined - >( + intercept( name: string, when: RequestInterceptor['when'], - responseChunks?: TResponseChunks - ): TResponseChunks extends undefined - ? { waitForIntercept: () => Promise } - : { completeAfterIntercept: () => Promise } { + responseChunks?: + | LLMMessage + | ((body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming) => LLMMessage) + ): { + waitForIntercept: () => Promise; + completeAfterIntercept: () => Promise; + } { const waitForInterceptPromise = Promise.race([ new Promise((outerResolve) => { this.interceptors.push({ name, when, handle: (request, response, requestBody) => { - this.log.info(`LLM request intercepted by "${name}"`); - function write(chunk: string) { return new Promise((resolve) => response.write(chunk, () => resolve())); } @@ -186,22 +221,26 @@ export class LlmProxy { requestBody, status: once(async (status: number) => { response.writeHead(status, { + 'Elastic-Interceptor': name, 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', Connection: 'keep-alive', }); }), next: (msg) => { + simulator.status(200); const chunk = createOpenAiChunk(msg); - return write(`data: ${JSON.stringify(chunk)}\n\n`); + return write(sseEvent(chunk)); }, rawWrite: (chunk: string) => { + simulator.status(200); return write(chunk); }, rawEnd: async () => { await end(); }, complete: async () => { + this.log.debug(`Completed intercept for "${name}"`); await write('data: [DONE]\n\n'); await end(); }, @@ -216,29 +255,41 @@ export class LlmProxy { }); }), new Promise((_, reject) => { - setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 20000ms`)), 20000); + setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 30000ms`)), 30000); }), ]); - if (responseChunks === undefined) { - return { waitForIntercept: () => waitForInterceptPromise } as any; - } - - const parsedChunks = Array.isArray(responseChunks) - ? responseChunks - : isString(responseChunks) - ? responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`)) - : [responseChunks]; - return { + waitForIntercept: () => waitForInterceptPromise, completeAfterIntercept: async () => { const simulator = await waitForInterceptPromise; + + function getParsedChunks(): Array { + const llmMessage = isFunction(responseChunks) + ? responseChunks(simulator.requestBody) + : responseChunks; + + if (!llmMessage) { + return []; + } + + if (Array.isArray(llmMessage)) { + return llmMessage; + } + + if (isString(llmMessage)) { + return llmMessage.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`)); + } + + return [llmMessage]; + } + + const parsedChunks = getParsedChunks(); for (const chunk of parsedChunks) { await simulator.next(chunk); } await simulator.complete(); - return simulator; }, } as any; @@ -271,11 +322,15 @@ async function getRequestBody( }); } -export function isFunctionTitleRequest( - requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming -) { - return ( - requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !== - undefined - ); +// function isFunctionTitleRequest( +// requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming +// ) { +// return ( +// requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !== +// undefined +// ); +// } + +function sseEvent(chunk: unknown) { + return `data: ${JSON.stringify(chunk)}\n\n`; }