Skip to content

Commit

Permalink
feat(js/ai): added toolChoice option to generate requests (including …
Browse files Browse the repository at this point in the history
…chat and prompt)
  • Loading branch information
pavelgj committed Jan 4, 2025
1 parent 99eb347 commit 368daa6
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 47 deletions.
2 changes: 2 additions & 0 deletions js/ai/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ export class Chat {
...(await this.requestBase),
// these things may get changed by tools calling within generate.
tools: response?.request?.tools,
toolChoice: response?.request?.toolChoice,
config: response?.request?.config,
});
await this.updateMessages(response.messages);
Expand Down Expand Up @@ -227,6 +228,7 @@ export class Chat {
...(await this.requestBase),
// these things may get changed by tools calling within generate.
tools: resolvedResponse?.request?.tools,
toolChoice: resolvedResponse?.request?.toolChoice,
config: resolvedResponse?.request?.config,
});
this.updateMessages(resolvedResponse.messages);
Expand Down
6 changes: 6 additions & 0 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ import { ExecutablePrompt } from './prompt.js';
import { ToolArgument, resolveTools, toToolDefinition } from './tool.js';
export { GenerateResponse, GenerateResponseChunk };

/** Specifies how tools should be called by the model. */
export type ToolChoice = 'auto' | 'required' | 'none';

export interface OutputOptions<O extends z.ZodTypeAny = z.ZodTypeAny> {
format?: string;
contentType?: string;
Expand All @@ -71,6 +74,8 @@ export interface GenerateOptions<
messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[];
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
tools?: ToolArgument[];
/** Specifies how tools should be called by the model. */
toolChoice?: ToolChoice;
/** Configuration for the generation request. */
config?: z.infer<CustomOptions>;
/** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */
Expand Down Expand Up @@ -274,6 +279,7 @@ export async function generate<
docs: resolvedOptions.docs,
messages: injectInstructions(messages, instructions),
tools,
toolChoice: resolvedOptions.toolChoice,
config: {
version: resolvedModel.version,
...stripUndefinedOptions(resolvedModel.config),
Expand Down
32 changes: 28 additions & 4 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ import {
GenerateResponseData,
MessageData,
MessageSchema,
ModelInfo,
ModelMiddleware,
ModelRequest,
Part,
Role,
ToolDefinitionSchema,
Expand All @@ -59,6 +61,8 @@ export const GenerateUtilParamSchema = z.object({
messages: z.array(MessageSchema),
/** List of registered tool names for this generation if supported by the underlying model. */
tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(),
/** Tool calling mode. */
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
/** Configuration for the generation request. */
config: z.any().optional(),
/** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */
Expand Down Expand Up @@ -141,7 +145,8 @@ async function generate(
const request = await actionToGenerateRequest(
rawRequest,
tools,
resolvedFormat
resolvedFormat,
model.__action.metadata?.model as ModelInfo
);

const accumulatedChunks: GenerateResponseChunkData[] = [];
Expand Down Expand Up @@ -210,6 +215,7 @@ async function generate(
const toolResponses: ToolResponsePart[] = [];
let messages: MessageData[] = [...request.messages, message];
let newTools = rawRequest.tools;
let newToolChoice = rawRequest.toolChoice;
for (const part of toolCalls) {
if (!part.toolRequest) {
throw Error(
Expand All @@ -235,6 +241,7 @@ async function generate(
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
} else {
toolResponses.push({
toolResponse: {
Expand All @@ -255,6 +262,7 @@ async function generate(
},
] as MessageData[],
tools: newTools,
// Carry the tool choice (possibly replaced by a new preamble above) into the follow-up request.
toolChoice: newToolChoice,
};
return await generateHelper(
registry,
Expand All @@ -267,9 +275,22 @@ async function generate(
async function actionToGenerateRequest(
options: z.infer<typeof GenerateUtilParamSchema>,
resolvedTools?: ToolAction[],
resolvedFormat?: Formatter
resolvedFormat?: Formatter,
modelInfo?: ModelInfo
): Promise<GenerateRequest> {
const out = {
if ((options.tools?.length ?? 0) > 0 && !modelInfo?.supports?.tools) {
logger.warn(
`The model does not support tools (you set: ${options.tools?.length} tools). ` +
'The model may not behave the way you expect.'
);
}
if (options.toolChoice && !modelInfo?.supports?.toolChoice) {
logger.warn(
`The model does not support toolChoice option (you set: ${options.toolChoice}). ` +
'The model may not behave the way you expect.'
);
}
const out: ModelRequest = {
messages: options.messages,
config: options.config,
docs: options.docs,
Expand All @@ -281,7 +302,10 @@ async function actionToGenerateRequest(
}),
},
};
if (!out.output.schema) delete out.output.schema;
if (options.toolChoice) {
out.toolChoice = options.toolChoice;
}
if (out.output && !out.output.schema) delete out.output.schema;
return out;
}

Expand Down
8 changes: 0 additions & 8 deletions js/ai/src/generate/response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,6 @@ export class GenerateResponse<O = unknown> implements ModelResponseData {
this.request = options?.request;
}

private get assertMessage(): Message<O> {
if (!this.message)
throw new Error(
'Operation could not be completed because the response does not contain a generated message.'
);
return this.message;
}

/**
* Throws an error if the response does not contain valid output.
*/
Expand Down
3 changes: 3 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ export const ModelInfoSchema = z.object({
contentType: z.array(z.string()).optional(),
/** Model can natively support document-based context grounding. */
context: z.boolean().optional(),
/** Model supports controlling tool choice, e.g. forced tool calling. */
toolChoice: z.boolean().optional(),
})
.optional(),
/** At which stage of development this model is.
Expand Down Expand Up @@ -287,6 +289,7 @@ export const ModelRequestSchema = z.object({
messages: z.array(MessageSchema),
config: z.any().optional(),
tools: z.array(ToolDefinitionSchema).optional(),
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
output: OutputConfigSchema.optional(),
docs: z.array(DocumentDataSchema).optional(),
});
Expand Down
5 changes: 4 additions & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
GenerateOptions,
GenerateResponse,
GenerateStreamResponse,
ToolChoice,
} from './generate.js';
import {
GenerateRequest,
Expand All @@ -29,7 +30,7 @@ import {
ModelArgument,
ModelMiddleware,
} from './model.js';
import { ToolAction } from './tool.js';
import { ToolAction, ToolArgument } from './tool.js';

/**
* Prompt implementation function signature.
Expand Down Expand Up @@ -64,6 +65,8 @@ export interface PromptConfig<I extends z.ZodTypeAny = z.ZodTypeAny> {
inputSchema?: I;
inputJsonSchema?: JSONSchema7;
metadata?: Record<string, any>;
tools?: ToolArgument[];
toolChoice?: ToolChoice;
use?: ModelMiddleware[];
}

Expand Down
14 changes: 14 additions & 0 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,20 @@ export class Genkit implements HasRegistry {
} else if (opt.use) {
resultOptions.use = opt.use;
}
if ((promptResult as GenerateOptions).tools) {
resultOptions.tools = (promptResult as GenerateOptions).tools;
} else if (p.__config?.tools) {
resultOptions.tools = p.__config?.tools;
} else if (opt.tools) {
resultOptions.tools = opt.tools;
}
if ((promptResult as GenerateOptions).toolChoice) {
resultOptions.toolChoice = (promptResult as GenerateOptions).toolChoice;
} else if (p.__config?.toolChoice) {
resultOptions.toolChoice = p.__config?.toolChoice;
} else if (opt.toolChoice) {
resultOptions.toolChoice = opt.toolChoice;
}
delete (resultOptions as any).input;
if ((promptResult as GenerateOptions).prompt) {
resultOptions.prompt = (promptResult as GenerateOptions).prompt;
Expand Down
21 changes: 14 additions & 7 deletions js/genkit/tests/chat_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ describe('preamble', () => {
config: { temperature: 1 },
description: 'Agent B description',
tools: ['agentA'],
toolChoice: 'required',
},
'{{role "system"}} agent b'
);
Expand All @@ -200,6 +201,7 @@ describe('preamble', () => {
config: { temperature: 2 },
description: 'Agent A description',
tools: [agentB],
toolChoice: 'required',
},
async () => {
return {
Expand All @@ -218,14 +220,16 @@ describe('preamble', () => {
return {
message: {
role: 'model',
content: [{ text: 'hi from agent a' }],
content: [
{ text: `hi from agent a (toolChoice: ${req.toolChoice})` },
],
},
};
};

const session = ai.chat(agentA);
let { text } = await session.send('hi');
assert.strictEqual(text, 'hi from agent a');
assert.strictEqual(text, 'hi from agent a (toolChoice: required)');
assert.deepStrictEqual(pm.lastRequest, {
config: {
temperature: 2,
Expand Down Expand Up @@ -254,6 +258,7 @@ describe('preamble', () => {
},
},
],
toolChoice: 'required',
});

// transfer to agent B...
Expand All @@ -273,15 +278,15 @@ describe('preamble', () => {
ref: 'ref123',
},
}
: { text: 'hi from agent b' },
: { text: `hi from agent b (toolChoice: ${req.toolChoice})` },
],
},
};
};

({ text } = await session.send('pls transfer to b'));

assert.deepStrictEqual(text, 'hi from agent b');
assert.deepStrictEqual(text, 'hi from agent b (toolChoice: required)');
assert.deepStrictEqual(pm.lastRequest, {
config: {
// TODO: figure out if config should be swapped out as well...
Expand All @@ -299,7 +304,7 @@ describe('preamble', () => {
},
{
role: 'model',
content: [{ text: 'hi from agent a' }],
content: [{ text: 'hi from agent a (toolChoice: required)' }],
},
{
role: 'user',
Expand Down Expand Up @@ -343,6 +348,7 @@ describe('preamble', () => {
},
},
],
toolChoice: 'required',
});

// transfer back to agent A...
Expand Down Expand Up @@ -387,7 +393,7 @@ describe('preamble', () => {
},
{
role: 'model',
content: [{ text: 'hi from agent a' }],
content: [{ text: 'hi from agent a (toolChoice: required)' }],
},
{
role: 'user',
Expand Down Expand Up @@ -419,7 +425,7 @@ describe('preamble', () => {
},
{
role: 'model',
content: [{ text: 'hi from agent b' }],
content: [{ text: 'hi from agent b (toolChoice: required)' }],
},
{
role: 'user',
Expand Down Expand Up @@ -463,6 +469,7 @@ describe('preamble', () => {
},
},
],
toolChoice: 'required',
});
});

Expand Down
23 changes: 23 additions & 0 deletions js/genkit/tests/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,29 @@ describe('generate', () => {
});
});

it('calls the default model with tool choice', async () => {
const response = await ai.generate({
prompt: 'hi',
toolChoice: 'required',
});
assert.strictEqual(response.text, 'Echo: hi; config: {}');
assert.deepStrictEqual(response.request, {
config: {
version: undefined,
},
docs: undefined,
messages: [
{
role: 'user',
content: [{ text: 'hi' }],
},
],
output: {},
tools: [],
toolChoice: 'required',
});
});

it('streams the default model', async () => {
const { response, stream } = await ai.generateStream('hi');

Expand Down
3 changes: 3 additions & 0 deletions js/plugins/dotprompt/src/metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ export interface PromptMetadata<
/** Names of tools (registered separately) to allow use of in this prompt. */
tools?: ToolArgument[];

/** Specifies how tools should be called by the model. */
toolChoice?: 'auto' | 'required' | 'none';

/** Model configuration. Not all models support all options. */
config?: z.infer<Options>;

Expand Down
5 changes: 5 additions & 0 deletions js/plugins/dotprompt/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export class Dotprompt<I = unknown> implements PromptMetadata<z.ZodTypeAny> {
input?: PromptMetadata['input'];
output?: PromptMetadata['output'];
tools?: PromptMetadata['tools'];
toolChoice?: PromptMetadata['toolChoice'];
config?: PromptMetadata['config'];
use?: PromptMetadata['use'];

Expand Down Expand Up @@ -147,6 +148,7 @@ export class Dotprompt<I = unknown> implements PromptMetadata<z.ZodTypeAny> {
this.input = options.input || { schema: z.any() };
this.output = options.output;
this.tools = options.tools;
this.toolChoice = options.toolChoice;
this.config = options.config;
this.use = options.use;
this.template = template;
Expand Down Expand Up @@ -266,6 +268,9 @@ export class Dotprompt<I = unknown> implements PromptMetadata<z.ZodTypeAny> {
if (middleware.length > 0) {
res.use = middleware;
}
if (options.toolChoice || this.toolChoice) {
res.toolChoice = options.toolChoice ?? this.toolChoice;
}
return res;
}

Expand Down
Loading

0 comments on commit 368daa6

Please sign in to comment.