Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js/ai): added a simple way to interrupt tool execution #1583

Merged
merged 9 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,22 @@ export const GenerationUsageSchema = z.object({
});
export type GenerationUsage = z.infer<typeof GenerationUsageSchema>;

/** Model response finish reason enum. */
export const FinishReasonSchema = z.enum([
'stop',
'length',
'blocked',
'interrupted',
'other',
'unknown',
]);

/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. */
export const CandidateSchema = z.object({
index: z.number(),
message: MessageSchema,
usage: GenerationUsageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
custom: z.unknown(),
});
Expand All @@ -226,7 +236,7 @@ export type CandidateError = z.infer<typeof CandidateErrorSchema>;

export const ModelResponseSchema = z.object({
message: MessageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
latencyMs: z.number().optional(),
usage: GenerationUsageSchema.optional(),
Expand All @@ -238,9 +248,7 @@ export type ModelResponseData = z.infer<typeof ModelResponseSchema>;
export const GenerateResponseSchema = ModelResponseSchema.extend({
/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. Return `message`, `finishReason`, and `finishMessage` instead. */
candidates: z.array(CandidateSchema).optional(),
finishReason: z
.enum(['stop', 'length', 'blocked', 'other', 'unknown'])
.optional(),
finishReason: FinishReasonSchema.optional(),
});
export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;

Expand Down
38 changes: 14 additions & 24 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,7 @@
"$ref": "#/$defs/GenerationUsage"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"type": "string"
Expand Down Expand Up @@ -341,6 +334,17 @@
},
"additionalProperties": false
},
"FinishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"interrupted",
"other",
"unknown"
]
},
"GenerateRequest": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -489,14 +493,7 @@
"$ref": "#/$defs/Message"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"type": "string"
Expand Down Expand Up @@ -755,14 +752,7 @@
"$ref": "#/$defs/GenerateResponse/properties/message"
},
"finishReason": {
"type": "string",
"enum": [
"stop",
"length",
"blocked",
"other",
"unknown"
]
"$ref": "#/$defs/FinishReason"
},
"finishMessage": {
"$ref": "#/$defs/GenerateResponse/properties/finishMessage"
Expand Down
100 changes: 77 additions & 23 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ import {
ToolResponsePart,
resolveModel,
} from '../model.js';
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';
import {
ToolAction,
ToolInterruptError,
resolveTools,
toToolDefinition,
} from '../tool.js';

export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
Expand Down Expand Up @@ -238,6 +243,8 @@ async function generate(
let messages: MessageData[] = [...request.messages, message];
let newTools = options.rawRequest.tools;
let newToolChoice = options.rawRequest.toolChoice;
let interruptedParts: Part[] = [];
let pendingToolRequests: Part[] = [];
for (const part of toolCalls) {
if (!part.toolRequest) {
throw Error(
Expand All @@ -249,29 +256,62 @@ async function generate(
throw Error(`Tool ${part.toolRequest?.name} not found`);
}
if ((tool.__action.metadata.type as string) === 'prompt') {
const newPreamble = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: `transferred to ${part.toolRequest.name}`,
},
});
// swap out the preamble
messages = [
...tagAsPreamble(newPreamble.messages)!,
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
try {
const newPreamble = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: `transferred to ${part.toolRequest.name}`,
},
});
// swap out the preamble
messages = [
...tagAsPreamble(newPreamble.messages)!,
...messages.filter((m) => !m?.metadata?.preamble),
];
newTools = newPreamble.tools;
newToolChoice = newPreamble.toolChoice;
} catch (e) {
if (e instanceof ToolInterruptError) {
logger.debug(`interrupted tool ${part.toolRequest?.name}`);
part.metadata = { ...part.metadata, interrupt: e.metadata || true };
interruptedParts.push(part);
} else {
throw e;
}
}
} else {
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
});
try {
const toolOutput = await tool(part.toolRequest?.input);
toolResponses.push({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: toolOutput,
},
});
// we prep these in case any other tool gets interrupted.
pendingToolRequests.push({
...part,
metadata: {
...part.metadata,
pendingToolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: toolOutput,
},
},
});
} catch (e) {
if (e instanceof ToolInterruptError) {
logger.debug(`interrupted tool ${part.toolRequest?.name}`);
part.metadata = { ...part.metadata, interrupt: e.metadata || true };
interruptedParts.push(part);
} else {
throw e;
}
}
}
}
options.messageIndex++;
Expand Down Expand Up @@ -301,6 +341,20 @@ async function generate(
}
)
);
if (interruptedParts.length > 0) {
const nonToolParts =
(response.message?.content.filter((c) => !c.toolRequest) as Part[]) || [];
return {
...response.toJSON(),
finishReason: 'interrupted',
message: {
role: 'model',
content: nonToolParts
.concat(pendingToolRequests)
.concat(interruptedParts),
},
};
}
return await generateHelper(registry, {
rawRequest: nextRequest,
middleware: options.middleware,
Expand Down
1 change: 1 addition & 0 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export {
type RetrieverReference,
} from './retriever.js';
export {
ToolInterruptError,
asTool,
defineTool,
type ToolAction,
Expand Down
20 changes: 14 additions & 6 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ export type GenerationCommonConfig = typeof GenerationCommonConfigSchema;
/**
* Zod schema of output config.
*/
const OutputConfigSchema = z.object({
export const OutputConfigSchema = z.object({
format: z.string().optional(),
schema: z.record(z.any()).optional(),
constrained: z.boolean().optional(),
Expand Down Expand Up @@ -344,12 +344,22 @@ export const GenerationUsageSchema = z.object({
*/
export type GenerationUsage = z.infer<typeof GenerationUsageSchema>;

/** Model response finish reason enum. */
const FinishReasonSchema = z.enum([
'stop',
'length',
'blocked',
'interrupted',
'other',
'unknown',
]);

/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. */
export const CandidateSchema = z.object({
index: z.number(),
message: MessageSchema,
usage: GenerationUsageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
custom: z.unknown(),
});
Expand All @@ -370,7 +380,7 @@ export type CandidateError = z.infer<typeof CandidateErrorSchema>;
*/
export const ModelResponseSchema = z.object({
message: MessageSchema.optional(),
finishReason: z.enum(['stop', 'length', 'blocked', 'other', 'unknown']),
finishReason: FinishReasonSchema,
finishMessage: z.string().optional(),
latencyMs: z.number().optional(),
usage: GenerationUsageSchema.optional(),
Expand All @@ -391,9 +401,7 @@ export type ModelResponseData = z.infer<typeof ModelResponseSchema>;
export const GenerateResponseSchema = ModelResponseSchema.extend({
/** @deprecated All responses now return a single candidate. Only the first candidate will be used if supplied. Return `message`, `finishReason`, and `finishMessage` instead. */
candidates: z.array(CandidateSchema).optional(),
finishReason: z
.enum(['stop', 'length', 'blocked', 'other', 'unknown'])
.optional(),
finishReason: FinishReasonSchema.optional(),
});

/**
Expand Down
37 changes: 35 additions & 2 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ export function toToolDefinition(
return out;
}

export interface ToolFnOptions {
/**
* A function that can be called during tool execution that will result in the tool
* getting interrupted (immediately) and tool request returned to the upstream caller.
*/
interrupt: (metadata?: Record<string, any>) => void;
}

export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
input: z.infer<I>,
ctx: ToolFnOptions
) => Promise<z.infer<O>>;

/**
* Defines a tool.
*
Expand All @@ -174,7 +187,7 @@ export function toToolDefinition(
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
registry: Registry,
config: ToolConfig<I, O>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: ToolFn<I, O>
): ToolAction<I, O> {
const a = defineAction(
registry,
Expand All @@ -183,7 +196,27 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
actionType: 'tool',
metadata: { ...(config.metadata || {}), type: 'tool' },
},
(i) => fn(i)
(i) =>
fn(i, {
interrupt: interruptTool,
})
);
return a as ToolAction<I, O>;
}

/**
* Thrown when tools execution is interrupted. It's meant to be caugh by the framework, not public API.
*/
export class ToolInterruptError extends Error {
constructor(readonly metadata?: Record<string, any>) {
super();
}
}

/**
* Interrupts current tool execution causing tool request to be returned in the generation response.
* Should only be called within a tool.
*/
function interruptTool(metadata?: Record<string, any>) {
throw new ToolInterruptError(metadata);
}
4 changes: 2 additions & 2 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ import {
SessionError,
SessionOptions,
} from '@genkit-ai/ai/session';
import { resolveTools } from '@genkit-ai/ai/tool';
import { resolveTools, ToolFn } from '@genkit-ai/ai/tool';
import {
Action,
defineFlow,
Expand Down Expand Up @@ -218,7 +218,7 @@ export class Genkit implements HasRegistry {
*/
defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: ToolFn<I, O>
): ToolAction<I, O> {
return defineTool(this.registry, config, fn);
}
Expand Down
Loading
Loading