Skip to content

Commit

Permalink
feat(js/ai): added toolChoice option to generate requests (including …
Browse files Browse the repository at this point in the history
…chat and prompt) (#1571)
  • Loading branch information
pavelgj authored Jan 23, 2025
1 parent 0a242b9 commit d88951f
Show file tree
Hide file tree
Showing 16 changed files with 331 additions and 83 deletions.
1 change: 1 addition & 0 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ export const ModelRequestSchema = z.object({
messages: z.array(MessageSchema),
config: z.any().optional(),
tools: z.array(ToolDefinitionSchema).optional(),
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
output: OutputConfigSchema.optional(),
context: z.array(DocumentDataSchema).optional(),
});
Expand Down
11 changes: 11 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@
"$ref": "#/$defs/ToolDefinition"
}
},
"toolChoice": {
"type": "string",
"enum": [
"auto",
"required",
"none"
]
},
"output": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -698,6 +706,9 @@
"tools": {
"$ref": "#/$defs/GenerateRequest/properties/tools"
},
"toolChoice": {
"$ref": "#/$defs/GenerateRequest/properties/toolChoice"
},
"output": {
"$ref": "#/$defs/GenerateRequest/properties/output"
},
Expand Down
2 changes: 2 additions & 0 deletions js/ai/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ export class Chat {
...(await this.requestBase),
// these things may get changed by tools calling within generate.
tools: response?.request?.tools,
toolChoice: response?.request?.toolChoice,
config: response?.request?.config,
});
await this.updateMessages(response.messages);
Expand Down Expand Up @@ -252,6 +253,7 @@ export class Chat {
...(await this.requestBase),
// these things may get changed by tools calling within generate.
tools: resolvedResponse?.request?.tools,
toolChoice: resolvedResponse?.request?.toolChoice,
config: resolvedResponse?.request?.config,
});
this.updateMessages(resolvedResponse.messages);
Expand Down
15 changes: 10 additions & 5 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ import { ExecutablePrompt } from './prompt.js';
import { ToolArgument, resolveTools, toToolDefinition } from './tool.js';
export { GenerateResponse, GenerateResponseChunk };

/**
 * Specifies how tools should be called by the model.
 *
 * - `auto` lets the model decide whether to use tools.
 * - `required` forces the model to choose a tool.
 * - `none` forces the model not to use any tools.
 */
export type ToolChoice = 'auto' | 'required' | 'none';

export interface OutputOptions<O extends z.ZodTypeAny = z.ZodTypeAny> {
format?: string;
contentType?: string;
Expand All @@ -72,6 +75,8 @@ export interface GenerateOptions<
messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[];
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
tools?: ToolArgument[];
/** Specifies how tools should be called by the model. */
toolChoice?: ToolChoice;
/** Configuration for the generation request. */
config?: z.infer<CustomOptions>;
/** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */
Expand Down Expand Up @@ -275,6 +280,7 @@ export async function generate<
docs: resolvedOptions.docs,
messages: injectInstructions(messages, instructions),
tools,
toolChoice: resolvedOptions.toolChoice,
config: {
version: resolvedModel.version,
...stripUndefinedOptions(resolvedModel.config),
Expand All @@ -292,11 +298,10 @@ export async function generate<
registry,
stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback),
async () => {
const response = await generateHelper(
registry,
params,
resolvedOptions.use
);
const response = await generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const request = await toGenerateRequest(registry, {
...resolvedOptions,
tools,
Expand Down
130 changes: 86 additions & 44 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ import {
GenerateResponseData,
MessageData,
MessageSchema,
ModelAction,
ModelInfo,
ModelMiddleware,
ModelRequest,
Part,
Role,
ToolDefinitionSchema,
Expand All @@ -59,6 +62,8 @@ export const GenerateUtilParamSchema = z.object({
messages: z.array(MessageSchema),
/** List of registered tool names for this generation if supported by the underlying model. */
tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(),
/** Tool calling mode. `auto` lets the model decide whether to use tools, `required` forces the model to choose a tool, and `none` forces the model not to use any tools. Defaults to `auto`. */
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
/** Configuration for the generation request. */
config: z.any().optional(),
/** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */
Expand All @@ -81,13 +86,15 @@ export const GenerateUtilParamSchema = z.object({
*/
export async function generateHelper(
registry: Registry,
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: ModelMiddleware[],
currentTurns?: number,
messageIndex?: number
options: {
rawRequest: z.infer<typeof GenerateUtilParamSchema>;
middleware?: ModelMiddleware[];
currentTurn?: number;
messageIndex?: number;
}
): Promise<GenerateResponseData> {
currentTurns = currentTurns ?? 0;
messageIndex = messageIndex ?? 0;
let currentTurn = options.currentTurn ?? 0;
let messageIndex = options.messageIndex ?? 0;
// do tracing
return await runInNewSpan(
registry,
Expand All @@ -101,14 +108,13 @@ export async function generateHelper(
},
async (metadata) => {
metadata.name = 'generate';
metadata.input = input;
const output = await generate(
registry,
input,
middleware,
currentTurns!,
messageIndex!
);
metadata.input = options.rawRequest;
const output = await generate(registry, {
rawRequest: options.rawRequest,
middleware: options.middleware,
currentTurn,
messageIndex,
});
metadata.output = JSON.stringify(output);
return output;
}
Expand All @@ -117,22 +123,30 @@ export async function generateHelper(

async function generate(
registry: Registry,
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware: ModelMiddleware[] | undefined,
currentTurn: number,
messageIndex: number
options: {
rawRequest: z.infer<typeof GenerateUtilParamSchema>;
middleware: ModelMiddleware[] | undefined;
currentTurn: number;
messageIndex: number;
}
): Promise<GenerateResponseData> {
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
const { modelAction: model } = await resolveModel(
registry,
options.rawRequest.model
);
if (model.__action.metadata?.model.stage === 'deprecated') {
logger.warn(
`${clc.bold(clc.yellow('Warning:'))} ` +
`Model '${model.__action.name}' is deprecated and may be removed in a future release.`
);
}

const tools = await resolveTools(registry, rawRequest.tools);
const tools = await resolveTools(registry, options.rawRequest.tools);

const resolvedFormat = await resolveFormat(registry, rawRequest.output);
const resolvedFormat = await resolveFormat(
registry,
options.rawRequest.output
);
// Create a lookup of tool names with namespaces stripped to original names
const toolMap = tools.reduce<Record<string, ToolAction>>((acc, tool) => {
const name = tool.__action.name;
Expand All @@ -148,9 +162,10 @@ async function generate(
}, {});

const request = await actionToGenerateRequest(
rawRequest,
options.rawRequest,
tools,
resolvedFormat
resolvedFormat,
model
);

const accumulatedChunks: GenerateResponseChunkData[] = [];
Expand All @@ -164,7 +179,7 @@ async function generate(
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, {
index: messageIndex,
index: options.messageIndex,
role: 'model',
previousChunks: accumulatedChunks,
parser: resolvedFormat?.handler(request.output?.schema)
Expand All @@ -180,12 +195,12 @@ async function generate(
index: number,
req: z.infer<typeof GenerateRequestSchema>
) => {
if (!middleware || index === middleware.length) {
if (!options.middleware || index === options.middleware.length) {
// end of the chain, call the original model action
return await model(req);
}

const currentMiddleware = middleware[index];
const currentMiddleware = options.middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
Expand All @@ -203,14 +218,14 @@ async function generate(
const message = response.message!; // would have thrown if no message

const toolCalls = message.content.filter((part) => !!part.toolRequest);
if (rawRequest.returnToolRequests || toolCalls.length === 0) {
if (options.rawRequest.returnToolRequests || toolCalls.length === 0) {
if (toolCalls.length === 0) {
response.assertValidSchema(request);
}
return response.toJSON();
}
const maxIterations = rawRequest.maxTurns ?? 5;
if (currentTurn + 1 > maxIterations) {
const maxIterations = options.rawRequest.maxTurns ?? 5;
if (options.currentTurn + 1 > maxIterations) {
throw new GenerationResponseError(
response,
`Exceeded maximum tool call iterations (${maxIterations})`,
Expand All @@ -221,7 +236,8 @@ async function generate(

const toolResponses: ToolResponsePart[] = [];
let messages: MessageData[] = [...request.messages, message];
let newTools = rawRequest.tools;
let newTools = options.rawRequest.tools;
let newToolChoice = options.rawRequest.toolChoice;
for (const part of toolCalls) {
if (!part.toolRequest) {
throw Error(
Expand All @@ -247,6 +263,7 @@ async function generate(
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
} else {
toolResponses.push({
toolResponse: {
Expand All @@ -257,9 +274,9 @@ async function generate(
});
}
}
messageIndex++;
options.messageIndex++;
const nextRequest = {
...rawRequest,
...options.rawRequest,
messages: [
...messages,
{
Expand All @@ -268,6 +285,7 @@ async function generate(
},
] as MessageData[],
tools: newTools,
// Must be spelled `toolChoice` so it overrides the value spread in from
// `options.rawRequest`; otherwise a tool choice supplied by an interrupting
// preamble is silently dropped on the next turn.
toolChoice: newToolChoice,
};
// stream out the tool responses
streamingCallback?.(
Expand All @@ -276,28 +294,49 @@ async function generate(
content: toolResponses,
},
{
index: messageIndex,
index: options.messageIndex,
role: 'model',
previousChunks: accumulatedChunks,
parser: resolvedFormat?.handler(request.output?.schema).parseChunk,
}
)
);
return await generateHelper(
registry,
nextRequest,
middleware,
currentTurn + 1,
messageIndex + 1
);
return await generateHelper(registry, {
rawRequest: nextRequest,
middleware: options.middleware,
currentTurn: options.currentTurn + 1,
messageIndex: options.messageIndex + 1,
});
}

async function actionToGenerateRequest(
options: z.infer<typeof GenerateUtilParamSchema>,
resolvedTools?: ToolAction[],
resolvedFormat?: Formatter
resolvedTools: ToolAction[] | undefined,
resolvedFormat: Formatter | undefined,
model: ModelAction
): Promise<GenerateRequest> {
const out = {
const modelInfo = model.__action.metadata?.model as ModelInfo;
if (
(options.tools?.length ?? 0) > 0 &&
modelInfo?.supports &&
!modelInfo?.supports?.tools
) {
logger.warn(
`The model '${model.__action.name}' does not support tools (you set: ${options.tools?.length} tools). ` +
'The model may not behave the way you expect.'
);
}
if (
options.toolChoice &&
modelInfo?.supports &&
!modelInfo?.supports?.toolChoice
) {
logger.warn(
`The model '${model.__action.name}' does not support the 'toolChoice' option (you set: ${options.toolChoice}). ` +
'The model may not behave the way you expect.'
);
}
const out: ModelRequest = {
messages: options.messages,
config: options.config,
docs: options.docs,
Expand All @@ -309,7 +348,10 @@ async function actionToGenerateRequest(
}),
},
};
if (!out.output.schema) delete out.output.schema;
if (options.toolChoice) {
out.toolChoice = options.toolChoice;
}
if (out.output && !out.output.schema) delete out.output.schema;
return out;
}

Expand Down
3 changes: 3 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ export const ModelInfoSchema = z.object({
contentType: z.array(z.string()).optional(),
/** Model can natively support document-based context grounding. */
context: z.boolean().optional(),
/** Model supports controlling tool choice, e.g. forced tool calling. */
toolChoice: z.boolean().optional(),
})
.optional(),
/** At which stage of development this model is.
Expand Down Expand Up @@ -287,6 +289,7 @@ export const ModelRequestSchema = z.object({
messages: z.array(MessageSchema),
config: z.any().optional(),
tools: z.array(ToolDefinitionSchema).optional(),
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
output: OutputConfigSchema.optional(),
docs: z.array(DocumentDataSchema).optional(),
});
Expand Down
Loading

0 comments on commit d88951f

Please sign in to comment.