Skip to content

Commit

Permalink
feat(js/ai): added a simple way to interrupt tool execution (#1583)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Jan 23, 2025
1 parent d88951f commit dead0b8
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 62 deletions.
18 changes: 13 additions & 5 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,22 @@ export const GenerationUsageSchema = z.object({
});
export type GenerationUsage = z.infer<typeof GenerationUsageSchema>;

/** Model response finish reason enum. */
export const FinishReasonSchema = z.enum([
'stop',
'length',
'blocked',
'interrupted',
'other',
'unknown',
]);

/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. */
export const CandidateSchema = z.object({
index: z.number(),
message: MessageSchema,
usage: GenerationUsageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
custom: z.unknown(),
});
Expand All @@ -226,7 +236,7 @@ export type CandidateError = z.infer<typeof CandidateErrorSchema>;

export const ModelResponseSchema = z.object({
message: MessageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
latencyMs: z.number().optional(),
usage: GenerationUsageSchema.optional(),
Expand All @@ -238,9 +248,7 @@ export type ModelResponseData = z.infer<typeof ModelResponseSchema>;
export const GenerateResponseSchema = ModelResponseSchema.extend({
/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. Return `message`, `finishReason`, and `finishMessage` instead. */
candidates: z.array(CandidateSchema).optional(),
finishReason: z
.enum(['stop', 'length', 'blocked', 'other', 'unknown'])
.optional(),
finishReason: FinishReasonSchema.optional(),
});
export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;

Expand Down
38 changes: 14 additions & 24 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,7 @@
"$ref": "#/$defs/GenerationUsage"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"type": "string"
Expand Down Expand Up @@ -341,6 +334,17 @@
},
"additionalProperties": false
},
"FinishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"interrupted",
"other",
"unknown"
]
},
"GenerateRequest": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -489,14 +493,7 @@
"$ref": "#/$defs/Message"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"type": "string"
Expand Down Expand Up @@ -755,14 +752,7 @@
"$ref": "#/$defs/GenerateResponse/properties/message"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"$ref": "#/$defs/GenerateResponse/properties/finishMessage"
Expand Down
100 changes: 77 additions & 23 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ import {
ToolResponsePart,
resolveModel,
} from '../model.js';
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';
import {
ToolAction,
ToolInterruptError,
resolveTools,
toToolDefinition,
} from '../tool.js';

export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
Expand Down Expand Up @@ -238,6 +243,8 @@ async function generate(
let messages: MessageData[] = [...request.messages, message];
let newTools = options.rawRequest.tools;
let newToolChoice = options.rawRequest.toolChoice;
let interruptedParts: Part[] = [];
let pendingToolRequests: Part[] = [];
for (const part of toolCalls) {
if (!part.toolRequest) {
throw Error(
Expand All @@ -249,29 +256,62 @@ async function generate(
throw Error(`Tool ${part.toolRequest?.name} not found`);
}
if ((tool.__action.metadata.type as string) === 'prompt') {
const newPreamble = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: `transferred to ${part.toolRequest.name}`,
},
});
// swap out the preamble
messages = [
...tagAsPreamble(newPreamble.messages)!,
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
try {
const newPreamble = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: `transferred to ${part.toolRequest.name}`,
},
});
// swap out the preamble
messages = [
...tagAsPreamble(newPreamble.messages)!,
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
} catch (e) {
if (e instanceof ToolInterruptError) {
logger.debug(`interrupted tool ${part.toolRequest?.name}`);
part.metadata = { ...part.metadata, interrupt: e.metadata || true };
interruptedParts.push(part);
} else {
throw e;
}
}
} else {
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
});
try {
const toolOutput = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: toolOutput,
},
});
// we prep these in case any other tool gets interrupted.
pendingToolRequests.push({
...part,
metadata: {
...part.metadata,
pendingToolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: toolOutput,
},
},
});
} catch (e) {
if (e instanceof ToolInterruptError) {
logger.debug(`interrupted tool ${part.toolRequest?.name}`);
part.metadata = { ...part.metadata, interrupt: e.metadata || true };
interruptedParts.push(part);
} else {
throw e;
}
}
}
}
options.messageIndex++;
Expand Down Expand Up @@ -301,6 +341,20 @@ async function generate(
}
)
);
if (interruptedParts.length > 0) {
const nonToolParts =
(response.message?.content.filter((c) => !c.toolRequest) as Part[]) || [];
return {
...response.toJSON(),
finishReason: 'interrupted',
message: {
role: 'model',
content: nonToolParts
.concat(pendingToolRequests)
.concat(interruptedParts),
},
};
}
return await generateHelper(registry, {
rawRequest: nextRequest,
middleware: options.middleware,
Expand Down
1 change: 1 addition & 0 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export {
type RetrieverReference,
} from './retriever.js';
export {
ToolInterruptError,
asTool,
defineTool,
type ToolAction,
Expand Down
20 changes: 14 additions & 6 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ export type GenerationCommonConfig = typeof GenerationCommonConfigSchema;
/**
* Zod schema of output config.
*/
const OutputConfigSchema = z.object({
export const OutputConfigSchema = z.object({
format: z.string().optional(),
schema: z.record(z.any()).optional(),
constrained: z.boolean().optional(),
Expand Down Expand Up @@ -344,12 +344,22 @@ export const GenerationUsageSchema = z.object({
*/
export type GenerationUsage = z.infer<typeof GenerationUsageSchema>;

/** Model response finish reason enum. */
const FinishReasonSchema = z.enum([
'stop',
'length',
'blocked',
'interrupted',
'other',
'unknown',
]);

/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. */
export const CandidateSchema = z.object({
index: z.number(),
message: MessageSchema,
usage: GenerationUsageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
custom: z.unknown(),
});
Expand All @@ -370,7 +380,7 @@ export type CandidateError = z.infer<typeof CandidateErrorSchema>;
*/
export const ModelResponseSchema = z.object({
message: MessageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
latencyMs: z.number().optional(),
usage: GenerationUsageSchema.optional(),
Expand All @@ -391,9 +401,7 @@ export type ModelResponseData = z.infer<typeof ModelResponseSchema>;
export const GenerateResponseSchema = ModelResponseSchema.extend({
/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. Return `message`, `finishReason`, and `finishMessage` instead. */
candidates: z.array(CandidateSchema).optional(),
finishReason: z
.enum(['stop', 'length', 'blocked', 'other', 'unknown'])
.optional(),
finishReason: FinishReasonSchema.optional(),
});

/**
Expand Down
37 changes: 35 additions & 2 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ export function toToolDefinition(
return out;
}

export interface ToolFnOptions {
/**
* A function that can be called during tool execution that will result in the tool
* getting interrupted (immediately) and tool request returned to the upstream caller.
*/
interrupt: (metadata?: Record<string, any>) => void;
}

export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
input: z.infer<I>,
ctx: ToolFnOptions
) => Promise<z.infer<O>>;

/**
* Defines a tool.
*
Expand All @@ -174,7 +187,7 @@ export function toToolDefinition(
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
registry: Registry,
config: ToolConfig<I, O>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: ToolFn<I, O>
): ToolAction<I, O> {
const a = defineAction(
registry,
Expand All @@ -183,7 +196,27 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
actionType: 'tool',
metadata: { ...(config.metadata || {}), type: 'tool' },
},
(i) => fn(i)
(i) =>
fn(i, {
interrupt: interruptTool,
})
);
return a as ToolAction<I, O>;
}

/**
* Thrown when tools execution is interrupted. It's meant to be caugh by the framework, not public API.
*/
export class ToolInterruptError extends Error {
constructor(readonly metadata?: Record<string, any>) {
super();
}
}

/**
* Interrupts current tool execution causing tool request to be returned in the generation response.
* Should only be called within a tool.
*/
function interruptTool(metadata?: Record<string, any>) {
throw new ToolInterruptError(metadata);
}
4 changes: 2 additions & 2 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ import {
SessionError,
SessionOptions,
} from '@genkit-ai/ai/session';
import { resolveTools } from '@genkit-ai/ai/tool';
import { resolveTools, ToolFn } from '@genkit-ai/ai/tool';
import {
Action,
defineFlow,
Expand Down Expand Up @@ -218,7 +218,7 @@ export class Genkit implements HasRegistry {
*/
defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: ToolFn<I, O>
): ToolAction<I, O> {
return defineTool(this.registry, config, fn);
}
Expand Down
Loading

0 comments on commit dead0b8

Please sign in to comment.