Commit

[Security AI] Bedrock prompt tuning and inference corrections (elastic#209011)

(cherry picked from commit 0d415a6)

# Conflicts:
#	x-pack/solutions/search/plugins/enterprise_search/public/applications/app_search/utils/encode_path_params/index.ts
#	x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.ts
stephmilovic committed Jan 31, 2025
1 parent 17b9037 commit b2b2e1e
Showing 12 changed files with 302 additions and 42 deletions.
@@ -50,7 +50,7 @@ const AS_PLAIN_TEXT: EuiComboBoxSingleSelectionShape = { asPlainText: true };
*/
export const EvaluationSettings: React.FC = React.memo(() => {
const { actionTypeRegistry, http, setTraceOptions, toasts, traceOptions } = useAssistantContext();
const { data: connectors } = useLoadConnectors({ http });
const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true });
const { mutate: performEvaluation, isLoading: isPerformingEvaluation } = usePerformEvaluation({
http,
toasts,
@@ -5,8 +5,11 @@
* 2.0.
*/

export { promptType } from './src/saved_object_mappings';
export { getPrompt, getPromptsByGroupId, resolveProviderAndModel } from './src/get_prompt';
export {
type PromptArray,
type Prompt,
type GetPromptArgs,
type GetPromptsByGroupIdArgs,
} from './src/types';
@@ -0,0 +1,213 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { elasticModelDictionary } from '@kbn/inference-common';
import { PromptArray, Prompt, GetPromptArgs, GetPromptsByGroupIdArgs } from './types';
import { getProviderFromActionTypeId } from './utils';
import { promptSavedObjectType } from './saved_object_mappings';

/**
* Get prompts by feature (promptGroupId)
* Provide either model + provider or a connector to avoid an extra connector lookup
* @param actionsClient - actions client
* @param connector - connector, provide if available. No need to provide model and provider in this case
* @param connectorId - connector id
* @param localPrompts - local prompts object
* @param model - model. No need to provide if connector provided
* @param promptGroupId - feature id, should be common across promptIds
* @param promptIds - prompt ids with shared promptGroupId
* @param provider - provider. No need to provide if connector provided
* @param savedObjectsClient - saved objects client
*/
export const getPromptsByGroupId = async ({
actionsClient,
connector,
connectorId,
localPrompts,
model: providedModel,
promptGroupId,
promptIds,
provider: providedProvider,
savedObjectsClient,
}: GetPromptsByGroupIdArgs): Promise<PromptArray> => {
const { provider, model } = await resolveProviderAndModel({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector: connector,
});

const prompts = await savedObjectsClient.find<Prompt>({
type: promptSavedObjectType,
searchFields: ['promptGroupId'],
search: promptGroupId,
});
const promptsOnly = prompts?.saved_objects.map((p) => p.attributes) ?? [];

return promptIds.map((promptId) => {
const prompt = findPromptEntry({
prompts: promptsOnly.filter((p) => p.promptId === promptId) ?? [],
promptId,
promptGroupId,
provider,
model,
localPrompts,
});
if (!prompt) {
throw new Error(
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
);
}

return {
promptId,
prompt,
};
});
};
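// Illustrative call; the connector id, group id, and prompt ids below are hypothetical, used only to
// show the expected shape of the arguments and result:
//
//   const prompts = await getPromptsByGroupId({
//     actionsClient,
//     connectorId: 'my-bedrock-connector',
//     localPrompts,
//     promptGroupId: 'securityAssistant',
//     promptIds: ['systemPrompt', 'alertSummaryPrompt'],
//     savedObjectsClient,
//   });
//   // => [{ promptId: 'systemPrompt', prompt: '...' }, { promptId: 'alertSummaryPrompt', prompt: '...' }]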

/**
* Get prompt by promptId
* Provide either model + provider or a connector to avoid an extra connector lookup
* @param actionsClient - actions client
* @param connector - connector, provide if available. No need to provide model and provider in this case
* @param connectorId - connector id
* @param localPrompts - local prompts object
* @param model - model. No need to provide if connector provided
* @param promptId - prompt id
* @param promptGroupId - feature id, should be common across promptIds
* @param provider - provider. No need to provide if connector provided
* @param savedObjectsClient - saved objects client
*/
export const getPrompt = async ({
actionsClient,
connector,
connectorId,
localPrompts,
model: providedModel,
promptGroupId,
promptId,
provider: providedProvider,
savedObjectsClient,
}: GetPromptArgs): Promise<string> => {
const { provider, model } = await resolveProviderAndModel({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector: connector,
});

const prompts = await savedObjectsClient.find<Prompt>({
type: promptSavedObjectType,
filter: `${promptSavedObjectType}.attributes.promptId: "${promptId}" AND ${promptSavedObjectType}.attributes.promptGroupId: "${promptGroupId}"`,
fields: ['provider', 'model', 'prompt'],
});

const prompt = findPromptEntry({
prompts: prompts?.saved_objects.map((p) => p.attributes) ?? [],
promptId,
promptGroupId,
provider,
model,
localPrompts,
});
if (!prompt) {
throw new Error(
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
);
}

return prompt;
};
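// Illustrative call (hypothetical ids); resolves to a single prompt string and throws if no entry matches:
//
//   const systemPrompt = await getPrompt({
//     actionsClient,
//     connectorId: 'my-inference-connector',
//     localPrompts,
//     promptGroupId: 'securityAssistant',
//     promptId: 'systemPrompt',
//     savedObjectsClient,
//   });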

export const resolveProviderAndModel = async ({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector,
}: {
providedProvider?: string;
providedModel?: string;
connectorId: string;
actionsClient: PublicMethodsOf<ActionsClient>;
providedConnector?: Connector;
}): Promise<{ provider?: string; model?: string }> => {
let model = providedModel;
let provider = providedProvider;
if (!provider || !model || provider === 'inference') {
const connector = providedConnector ?? (await actionsClient.get({ id: connectorId }));

if (provider === 'inference' && connector.config) {
provider = connector.config.provider || provider;
model = connector.config.providerConfig?.model_id || model;

if (provider === 'elastic' && model) {
provider = elasticModelDictionary[model]?.provider || 'inference';
model = elasticModelDictionary[model]?.model;
}
} else if (connector.config) {
provider = provider || getProviderFromActionTypeId(connector.actionTypeId);
model = model || connector.config.defaultModel;
}
}

return { provider: provider === 'inference' ? 'bedrock' : provider, model };
};
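// Resolution behavior, sketched (connector shapes are assumed for illustration):
// - provider and model both provided, provider not 'inference': returned as-is, no connector lookup.
// - provider 'inference': provider/model are read from connector.config (provider, providerConfig.model_id),
//   and an 'elastic' provider is then remapped through elasticModelDictionary.
// - provider or model otherwise missing: derived from the connector's actionTypeId and config.defaultModel.
// - a provider still equal to 'inference' at the end is reported as 'bedrock'.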

const findPrompt = ({
prompts,
conditions,
}: {
prompts: Array<{ provider?: string; model?: string; prompt: { default: string } }>;
conditions: Array<(prompt: { provider?: string; model?: string }) => boolean>;
}): string | undefined => {
for (const condition of conditions) {
const match = prompts.find(condition);
if (match) return match.prompt.default;
}
return undefined;
};

const findPromptEntry = ({
prompts,
promptId,
promptGroupId,
provider,
model,
localPrompts,
}: {
localPrompts: Prompt[];
prompts: Prompt[];
promptId: string;
promptGroupId: string;
provider?: string;
model?: string;
}): string | undefined => {
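// Match precedence, applied to saved-object prompts first and then to the bundled local prompts:
// 1) exact provider + model match, 2) provider match with no model, 3) no provider and no model.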
const conditions = [
(prompt: { provider?: string; model?: string }) =>
prompt.provider === provider && prompt.model === model,
(prompt: { provider?: string; model?: string }) =>
prompt.provider === provider && !prompt.model,
(prompt: { provider?: string; model?: string }) => !prompt.provider && !prompt.model,
];

return (
findPrompt({ prompts, conditions }) ??
findPrompt({
prompts: localPrompts.filter(
(p) => p.promptId === promptId && p.promptGroupId === promptGroupId
),
conditions,
})
);
};
@@ -127,6 +127,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => contentReferencesEnabled,
},
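// New graph state channel: the resolved LLM provider (for example 'bedrock' or 'openai'), defaulting to ''.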
provider: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
};

// Default node parameters
@@ -56,6 +56,7 @@ describe('streamGraph', () => {
input: 'input',
responseLanguage: 'English',
llmType: 'openai',
provider: 'openai',
connectorId: '123',
},
logger: mockLogger,
@@ -291,6 +292,7 @@ describe('streamGraph', () => {
inputs: {
...requestArgs.inputs,
llmType: 'gemini',
provider: 'gemini',
},
});

@@ -306,6 +308,7 @@ describe('streamGraph', () => {
inputs: {
...requestArgs.inputs,
llmType: 'bedrock',
provider: 'bedrock',
},
});

@@ -136,17 +136,21 @@ export const streamGraph = async ({

// Stream is from openai functions agent
let finalMessage = '';
const stream = assistantGraph.streamEvents(inputs, {
callbacks: [
apmTracer,
...(traceOptions?.tracers ?? []),
...(telemetryTracer ? [telemetryTracer] : []),
],
runName: DEFAULT_ASSISTANT_GRAPH_ID,
streamMode: 'values',
tags: traceOptions?.tags ?? [],
version: 'v1',
});
const stream = assistantGraph.streamEvents(
inputs,
{
callbacks: [
apmTracer,
...(traceOptions?.tracers ?? []),
...(telemetryTracer ? [telemetryTracer] : []),
],
runName: DEFAULT_ASSISTANT_GRAPH_ID,
streamMode: 'values',
tags: traceOptions?.tags ?? [],
version: 'v1',
},
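// Third streamEvents argument: when the provider is Bedrock, only events from the 'Summarizer' run are
// streamed; other providers keep the default behavior of streaming all events.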
inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
);

const pushStreamUpdate = async () => {
for await (const { event, data, tags } of stream) {
@@ -155,8 +159,6 @@ export const streamGraph = async ({
const chunk = data?.chunk;
const msg = chunk.message;
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
// however, no harm to keep it in
/* empty */
} else if (!didEnd) {
push({ payload: msg.content, type: 'content' });
@@ -20,12 +20,15 @@ import {
} from 'langchain/agents';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
jest.mock('@kbn/langchain/server/tracers/apm');
jest.mock('@kbn/langchain/server/tracers/telemetry');
jest.mock('@kbn/security-ai-prompts');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
@@ -83,6 +86,9 @@ describe('callAssistantGraph', () => {
jest.clearAllMocks();
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
getDefaultAssistantGraphMock.mockReturnValue({});
resolveProviderAndModelMock.mockResolvedValue({
provider: 'bedrock',
});
(invokeGraph as jest.Mock).mockResolvedValue({
output: 'test-output',
traceData: {},
@@ -224,5 +230,23 @@ describe('callAssistantGraph', () => {
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
it('does not call resolveProviderAndModel when llmType === openai', async () => {
const params = { ...defaultParams, llmType: 'openai' };
await callAssistantGraph(params);

expect(resolveProviderAndModelMock).not.toHaveBeenCalled();
});
it('calls resolveProviderAndModel when llmType === inference', async () => {
const params = { ...defaultParams, llmType: 'inference' };
await callAssistantGraph(params);

expect(resolveProviderAndModelMock).toHaveBeenCalled();
});
it('calls resolveProviderAndModel when llmType === undefined', async () => {
const params = { ...defaultParams, llmType: undefined };
await callAssistantGraph(params);

expect(resolveProviderAndModelMock).toHaveBeenCalled();
});
});
});
@@ -15,6 +15,7 @@ import {
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
@@ -183,6 +184,13 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;
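// Resolve the concrete provider when the connector is an inference connector (or llmType is unknown) so the
// graph can branch on provider-specific behavior, such as the Bedrock streaming path.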
const { provider } =
!llmType || llmType === 'inference'
? await resolveProviderAndModel({
connectorId,
actionsClient,
})
: { provider: llmType };
const assistantGraph = getDefaultAssistantGraph({
agentRunnable,
dataClients,
@@ -205,6 +213,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
provider: provider ?? '',
};

if (isStream) {
Expand Down