diff --git a/js/ai/src/chat.ts b/js/ai/src/chat.ts index e84131057..2bef24528 100644 --- a/js/ai/src/chat.ts +++ b/js/ai/src/chat.ts @@ -172,6 +172,7 @@ export class Chat { ...(await this.requestBase), // these things may get changed by tools calling within generate. tools: response?.request?.tools, + toolChoice: response?.request?.toolChoice, config: response?.request?.config, }); await this.updateMessages(response.messages); @@ -227,6 +228,7 @@ export class Chat { ...(await this.requestBase), // these things may get changed by tools calling within generate. tools: resolvedResponse?.request?.tools, + toolChoice: resolvedResponse?.request?.toolChoice, config: resolvedResponse?.request?.config, }); this.updateMessages(resolvedResponse.messages); diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index a6ad66b94..5dd870d81 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -47,6 +47,9 @@ import { ExecutablePrompt } from './prompt.js'; import { ToolArgument, resolveTools, toToolDefinition } from './tool.js'; export { GenerateResponse, GenerateResponseChunk }; +/** Specifies how tools should be called by the model. */ +export type ToolChoice = 'auto' | 'required' | 'none'; + export interface OutputOptions { format?: string; contentType?: string; @@ -71,6 +74,8 @@ export interface GenerateOptions< messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[]; /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ tools?: ToolArgument[]; + /** Specifies how tools should be called by the model. */ + toolChoice?: ToolChoice; /** Configuration for the generation request. */ config?: z.infer; /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. 
*/ @@ -274,6 +279,7 @@ export async function generate< docs: resolvedOptions.docs, messages: injectInstructions(messages, instructions), tools, + toolChoice: resolvedOptions.toolChoice, config: { version: resolvedModel.version, ...stripUndefinedOptions(resolvedModel.config), diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 94bf6fdd1..57562b0fd 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -41,7 +41,9 @@ import { GenerateResponseData, MessageData, MessageSchema, + ModelInfo, ModelMiddleware, + ModelRequest, Part, Role, ToolDefinitionSchema, @@ -59,6 +61,8 @@ export const GenerateUtilParamSchema = z.object({ messages: z.array(MessageSchema), /** List of registered tool names for this generation if supported by the underlying model. */ tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), + /** Tool calling mode. */ + toolChoice: z.enum(['auto', 'required', 'none']).optional(), /** Configuration for the generation request. */ config: z.any().optional(), /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. 
*/ @@ -141,7 +145,8 @@ async function generate( const request = await actionToGenerateRequest( rawRequest, tools, - resolvedFormat + resolvedFormat, + model.__action.metadata?.model as ModelInfo ); const accumulatedChunks: GenerateResponseChunkData[] = []; @@ -210,6 +215,7 @@ async function generate( const toolResponses: ToolResponsePart[] = []; let messages: MessageData[] = [...request.messages, message]; let newTools = rawRequest.tools; + let newToolChoice = rawRequest.toolChoice; for (const part of toolCalls) { if (!part.toolRequest) { throw Error( @@ -235,6 +241,7 @@ async function generate( ...messages.filter((m) => !m?.metadata?.preamble), ]; newTools = newPreamble.tools; + newToolChoice = newPreamble.toolChoice; } else { toolResponses.push({ toolResponse: { @@ -255,6 +262,7 @@ async function generate( }, ] as MessageData[], tools: newTools, + toolChoice: newToolChoice, }; return await generateHelper( registry, @@ -267,9 +275,22 @@ async function generate( async function actionToGenerateRequest( options: z.infer, resolvedTools?: ToolAction[], - resolvedFormat?: Formatter + resolvedFormat?: Formatter, + modelInfo?: ModelInfo ): Promise { - const out = { + if ((options.tools?.length ?? 0) > 0 && !modelInfo?.supports?.tools) { + logger.warn( + `The model does not support tools (you set: ${options.tools?.length} tools). ` + + 'The model may not behave the way you expect.' + ); + } + if (options.toolChoice && !modelInfo?.supports?.toolChoice) { + logger.warn( + `The model does not support toolChoice option (you set: ${options.toolChoice}). ` + + 'The model may not behave the way you expect.' 
+ ); + } + const out: ModelRequest = { messages: options.messages, config: options.config, docs: options.docs, @@ -281,7 +302,10 @@ async function actionToGenerateRequest( }), }, }; - if (!out.output.schema) delete out.output.schema; + if (options.toolChoice) { + out.toolChoice = options.toolChoice; + } + if (out.output && !out.output.schema) delete out.output.schema; return out; } diff --git a/js/ai/src/generate/response.ts b/js/ai/src/generate/response.ts index f476bbea3..41b2981c3 100644 --- a/js/ai/src/generate/response.ts +++ b/js/ai/src/generate/response.ts @@ -73,14 +73,6 @@ export class GenerateResponse implements ModelResponseData { this.request = options?.request; } - private get assertMessage(): Message { - if (!this.message) - throw new Error( - 'Operation could not be completed because the response does not contain a generated message.' - ); - return this.message; - } - /** * Throws an error if the response does not contain valid output. */ diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 71a067b46..1ff587514 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -204,6 +204,8 @@ export const ModelInfoSchema = z.object({ contentType: z.array(z.string()).optional(), /** Model can natively support document-based context grounding. */ context: z.boolean().optional(), + /** Model supports controlling tool choice, e.g. forced tool calling. */ + toolChoice: z.boolean().optional(), }) .optional(), /** At which stage of development this model is. 
@@ -287,6 +289,7 @@ export const ModelRequestSchema = z.object({ messages: z.array(MessageSchema), config: z.any().optional(), tools: z.array(ToolDefinitionSchema).optional(), + toolChoice: z.enum(['auto', 'required', 'none']).optional(), output: OutputConfigSchema.optional(), docs: z.array(DocumentDataSchema).optional(), }); diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 2e47d534d..d434e3575 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -21,6 +21,7 @@ import { GenerateOptions, GenerateResponse, GenerateStreamResponse, + ToolChoice, } from './generate.js'; import { GenerateRequest, @@ -29,7 +30,7 @@ import { ModelArgument, ModelMiddleware, } from './model.js'; -import { ToolAction } from './tool.js'; +import { ToolAction, ToolArgument } from './tool.js'; /** * Prompt implementation function signature. @@ -64,6 +65,8 @@ export interface PromptConfig { inputSchema?: I; inputJsonSchema?: JSONSchema7; metadata?: Record; + tools?: ToolArgument[]; + toolChoice?: ToolChoice; use?: ModelMiddleware[]; } diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index d212afdc9..956d8ea9a 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -594,6 +594,20 @@ export class Genkit implements HasRegistry { } else if (opt.use) { resultOptions.use = opt.use; } + if ((promptResult as GenerateOptions).tools) { + resultOptions.tools = (promptResult as GenerateOptions).tools; + } else if (p.__config?.tools) { + resultOptions.tools = p.__config?.tools; + } else if (opt.tools) { + resultOptions.tools = opt.tools; + } + if ((promptResult as GenerateOptions).toolChoice) { + resultOptions.toolChoice = (promptResult as GenerateOptions).toolChoice; + } else if (p.__config?.toolChoice) { + resultOptions.toolChoice = p.__config?.toolChoice; + } else if (opt.toolChoice) { + resultOptions.toolChoice = opt.toolChoice; + } delete (resultOptions as any).input; if ((promptResult as GenerateOptions).prompt) { resultOptions.prompt = (promptResult as 
GenerateOptions).prompt; diff --git a/js/genkit/tests/chat_test.ts b/js/genkit/tests/chat_test.ts index 1d6c148de..50e938192 100644 --- a/js/genkit/tests/chat_test.ts +++ b/js/genkit/tests/chat_test.ts @@ -190,6 +190,7 @@ describe('preamble', () => { config: { temperature: 1 }, description: 'Agent B description', tools: ['agentA'], + toolChoice: 'required', }, '{{role "system"}} agent b' ); @@ -200,6 +201,7 @@ describe('preamble', () => { config: { temperature: 2 }, description: 'Agent A description', tools: [agentB], + toolChoice: 'required', }, async () => { return { @@ -218,14 +220,16 @@ describe('preamble', () => { return { message: { role: 'model', - content: [{ text: 'hi from agent a' }], + content: [ + { text: `hi from agent a (toolChoice: ${req.toolChoice})` }, + ], }, }; }; const session = ai.chat(agentA); let { text } = await session.send('hi'); - assert.strictEqual(text, 'hi from agent a'); + assert.strictEqual(text, 'hi from agent a (toolChoice: required)'); assert.deepStrictEqual(pm.lastRequest, { config: { temperature: 2, @@ -254,6 +258,7 @@ describe('preamble', () => { }, }, ], + toolChoice: 'required', }); // transfer to agent B... @@ -273,7 +278,7 @@ describe('preamble', () => { ref: 'ref123', }, } - : { text: 'hi from agent b' }, + : { text: `hi from agent b (toolChoice: ${req.toolChoice})` }, ], }, }; @@ -281,7 +286,7 @@ describe('preamble', () => { ({ text } = await session.send('pls transfer to b')); - assert.deepStrictEqual(text, 'hi from agent b'); + assert.deepStrictEqual(text, 'hi from agent b (toolChoice: required)'); assert.deepStrictEqual(pm.lastRequest, { config: { // TODO: figure out if config should be swapped out as well... 
@@ -299,7 +304,7 @@ describe('preamble', () => { }, { role: 'model', - content: [{ text: 'hi from agent a' }], + content: [{ text: 'hi from agent a (toolChoice: required)' }], }, { role: 'user', @@ -343,6 +348,7 @@ describe('preamble', () => { }, }, ], + toolChoice: 'required', }); // transfer back to to agent A... @@ -387,7 +393,7 @@ describe('preamble', () => { }, { role: 'model', - content: [{ text: 'hi from agent a' }], + content: [{ text: 'hi from agent a (toolChoice: required)' }], }, { role: 'user', @@ -419,7 +425,7 @@ describe('preamble', () => { }, { role: 'model', - content: [{ text: 'hi from agent b' }], + content: [{ text: 'hi from agent b (toolChoice: required)' }], }, { role: 'user', @@ -463,6 +469,7 @@ describe('preamble', () => { }, }, ], + toolChoice: 'required', }); }); diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 028ed2a38..9d94467c2 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -82,6 +82,29 @@ describe('generate', () => { }); }); + it('calls the default model with tool choice', async () => { + const response = await ai.generate({ + prompt: 'hi', + toolChoice: 'required', + }); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + assert.deepStrictEqual(response.request, { + config: { + version: undefined, + }, + docs: undefined, + messages: [ + { + role: 'user', + content: [{ text: 'hi' }], + }, + ], + output: {}, + tools: [], + toolChoice: 'required', + }); + }); + it('streams the default model', async () => { const { response, stream } = await ai.generateStream('hi'); diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 686d0b211..288e89bee 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -52,6 +52,9 @@ export interface PromptMetadata< /** Names of tools (registered separately) to allow use of in this prompt. 
*/ tools?: ToolArgument[]; + /** Specifies how tools should be called by the model. */ + toolChoice?: 'auto' | 'required' | 'none'; + /** Model configuration. Not all models support all options. */ config?: z.infer; diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 83bbeec5c..fe465b8de 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -81,6 +81,7 @@ export class Dotprompt implements PromptMetadata { input?: PromptMetadata['input']; output?: PromptMetadata['output']; tools?: PromptMetadata['tools']; + toolChoice?: PromptMetadata['toolChoice']; config?: PromptMetadata['config']; use?: PromptMetadata['use']; @@ -147,6 +148,7 @@ export class Dotprompt implements PromptMetadata { this.input = options.input || { schema: z.any() }; this.output = options.output; this.tools = options.tools; + this.toolChoice = options.toolChoice; this.config = options.config; this.use = options.use; this.template = template; @@ -266,6 +268,9 @@ export class Dotprompt implements PromptMetadata { if (middleware.length > 0) { res.use = middleware; } + if (options.toolChoice || this.toolChoice) { + res.toolChoice = options.toolChoice ?? 
this.toolChoice; + } return res; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index 52da58f69..71ca7233f 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -138,11 +138,13 @@ describe('Prompt', () => { input: { name: 'Michael' }, onChunk: streamingCallback, returnToolRequests: true, + toolChoice: 'required', maxTurns: 17, use: middleware, }); assert.strictEqual(rendered.onChunk, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.toolChoice, 'required'); assert.strictEqual(rendered.maxTurns, 17); assert.deepStrictEqual(rendered.use, middleware); }); @@ -505,4 +507,21 @@ describe('DotpromptRef', () => { }, ]); }); + + it('should render toolChoice', () => { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `hi`, { + toolChoice: 'required', + }); + + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.toolChoice, 'required'); + }); }); diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index ebd65adfe..977d70b37 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -680,19 +680,18 @@ export function defineGoogleAIModel( } let toolConfig: ToolConfig | undefined; - if ( - requestConfig.functionCallingConfig && - // This is a workround for issue: https://github.com/firebase/genkit/issues/1520 - // TODO: remove this when the issue is resolved upstream in the Gemini API - !messages.at(-1)?.content.find((c) => c.toolResponse) - ) { + if (requestConfig.functionCallingConfig) { toolConfig = { functionCallingConfig: { allowedFunctionNames: requestConfig.functionCallingConfig.allowedFunctionNames, - mode: toGeminiFunctionMode( - 
requestConfig.functionCallingConfig.mode - ), + mode: toFunctionModeEnum(requestConfig.functionCallingConfig.mode), + }, + }; + } else if (request.toolChoice) { + toolConfig = { + functionCallingConfig: { + mode: toGeminiFunctionModeEnum(request.toolChoice), }, }; } @@ -821,13 +820,14 @@ export function defineGoogleAIModel( ); } -function toGeminiFunctionMode( - genkitMode: string | undefined +/** Converts mode from the config, which follows Gemini naming convention. */ +function toFunctionModeEnum( + configEnum: string | undefined ): FunctionCallingMode | undefined { - if (genkitMode === undefined) { + if (configEnum === undefined) { return undefined; } - switch (genkitMode) { + switch (configEnum) { case 'MODE_UNSPECIFIED': { return FunctionCallingMode.MODE_UNSPECIFIED; } @@ -840,6 +840,28 @@ function toGeminiFunctionMode( case 'NONE': { return FunctionCallingMode.NONE; } + default: + throw new Error(`unsupported function calling mode: ${configEnum}`); + } +} + +/** Converts mode from genkit tool choice. 
*/ +function toGeminiFunctionModeEnum( + genkitMode: 'auto' | 'required' | 'none' +): FunctionCallingMode | undefined { + if (genkitMode === undefined) { + return undefined; + } + switch (genkitMode) { + case 'required': { + return FunctionCallingMode.ANY; + } + case 'auto': { + return FunctionCallingMode.AUTO; + } + case 'none': { + return FunctionCallingMode.NONE; + } default: throw new Error(`unsupported function calling mode: ${genkitMode}`); } diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 581001619..1e50fae3d 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -624,19 +624,18 @@ export function defineGeminiModel( : []; let toolConfig: ToolConfig | undefined; - if ( - request?.config?.functionCallingConfig && - // This is a workround for issue: https://github.com/firebase/genkit/issues/1520 - // TODO: remove this when the issue is resolved upstream in the Gemini API - !messages.at(-1)?.content.find((c) => c.toolResponse) - ) { + if (request?.config?.functionCallingConfig) { toolConfig = { functionCallingConfig: { allowedFunctionNames: request.config.functionCallingConfig.allowedFunctionNames, - mode: toGeminiFunctionMode( - request.config.functionCallingConfig.mode - ), + mode: toFunctionModeEnum(request.config.functionCallingConfig.mode), + }, + }; + } else if (request.toolChoice) { + toolConfig = { + functionCallingConfig: { + mode: toGeminiFunctionModeEnum(request.toolChoice), }, }; } @@ -798,13 +797,14 @@ export function defineGeminiModel( ); } -function toGeminiFunctionMode( - genkitMode: string | undefined +/** Converts mode from the config, which follows Gemini naming convention. 
*/ +function toFunctionModeEnum( + enumMode: string | undefined ): FunctionCallingMode | undefined { - if (genkitMode === undefined) { + if (enumMode === undefined) { return undefined; } - switch (genkitMode) { + switch (enumMode) { case 'MODE_UNSPECIFIED': { return FunctionCallingMode.MODE_UNSPECIFIED; } @@ -817,6 +817,28 @@ function toGeminiFunctionMode( case 'NONE': { return FunctionCallingMode.NONE; } + default: + throw new Error(`unsupported function calling mode: ${enumMode}`); + } +} + +/** Converts mode from genkit tool choice. */ +function toGeminiFunctionModeEnum( + genkitMode: 'auto' | 'required' | 'none' +): FunctionCallingMode | undefined { + if (genkitMode === undefined) { + return undefined; + } + switch (genkitMode) { + case 'required': { + return FunctionCallingMode.ANY; + } + case 'auto': { + return FunctionCallingMode.AUTO; + } + case 'none': { + return FunctionCallingMode.NONE; + } default: throw new Error(`unsupported function calling mode: ${genkitMode}`); } diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 8231c82b2..e2a609fc3 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -386,7 +386,20 @@ const jokeSubjectGenerator = ai.defineTool( } ); -export const toolCaller = ai.defineStreamingFlow( +const gablorkenTool = ai.defineTool( + { + name: 'gablorkenTool', + inputSchema: z.object({ + value: z.number(), + }), + description: 'can be used to calculate gablorken value', + }, + async (input) => { + return input.value * 3 - 4; + } +); + +export const toolCaller = ai.defineFlow( { name: 'toolCaller', outputSchema: z.string(), @@ -414,6 +427,49 @@ export const toolCaller = ai.defineStreamingFlow( } ); +const exitTool = ai.defineTool( + { + name: 'exitTool', + inputSchema: z.object({ + answer: z.number(), + }), + description: 'call this tool when you have the final answer', + }, + async (input) => { + throw new Error(`Answer: ${input.answer}`); + } 
+); + +export const forcedToolCaller = ai.defineFlow( + { + name: 'forcedToolCaller', + inputSchema: z.number(), + outputSchema: z.string(), + streamSchema: z.any(), + }, + async (input, streamingCallback) => { + if (!streamingCallback) { + throw new Error('this flow only works in streaming mode'); + } + + const { response, stream } = await ai.generateStream({ + model: gemini15Flash, + config: { + temperature: 1, + }, + tools: [gablorkenTool, exitTool], + toolChoice: 'required', + prompt: `what is a gablorken of ${input}`, + }); + + for await (const chunk of stream) { + streamingCallback(chunk); + } + + return (await response).text; + } +); + export const invalidOutput = ai.defineFlow( { name: 'invalidOutput',