Skip to content

Commit

Permalink
Merge pull request #40 from statelyai/davidkpiano/correlation-id
Browse files Browse the repository at this point in the history
Correlation ID
  • Loading branch information
davidkpiano authored Jul 21, 2024
2 parents e3b8cd9 + 8b7c374 commit 8cc48d0
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 195 deletions.
26 changes: 26 additions & 0 deletions .changeset/khaki-emus-jog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
'@statelyai/agent': minor
---

Correlation IDs are now provided as part of the result from `agent.generateText(…)` and `agent.streamText(…)`:

```ts
const result = await agent.generateText({
prompt: 'Write me a song',
correlationId: 'my-correlation-id',
// ...
});

result.correlationId; // 'my-correlation-id'
```

These correlation IDs can be passed to feedback:

```ts
// ...

agent.addFeedback({
reward: -1,
correlationId: result.correlationId,
});
```
11 changes: 11 additions & 0 deletions .changeset/silent-deers-visit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
'@statelyai/agent': minor
---

Changes to agent feedback (the `AgentFeedback` interface):

- `goal` is now optional
- `observationId` is now optional
- `correlationId` has been added (optional)
- `reward` has been added (optional)
- `attributes` are now optional
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"ts-node": "^10.9.2",
"tsup": "^8.2.0",
"typescript": "^5.5.3",
"vitest": "^1.6.0",
"vitest": "^2.0.3",
"wikipedia": "^2.1.2",
"zod": "^3.23.8"
},
Expand Down
235 changes: 76 additions & 159 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

151 changes: 149 additions & 2 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { test, expect, vi } from 'vitest';
import { createAgent, type AIAdapter } from './';
import {
AgentGenerateTextResult,
AgentMessage,
createAgent,
type AIAdapter,
} from './';
import { createActor, createMachine } from 'xstate';
import { GenerateTextResult } from 'ai';
import { z } from 'zod';
Expand Down Expand Up @@ -297,7 +302,7 @@ test('You can listen for plan events', async () => {
},
},
],
} as any as GenerateTextResult<any>;
} as any as AgentGenerateTextResult;
},
streamText: {} as any,
},
Expand Down Expand Up @@ -357,3 +362,145 @@ test('agent.types provides context and event types', () => {
// @ts-expect-error
agent.types.context satisfies { score: string };
});

test.each(['generateText', 'streamText'] as const)(
'can provide a correlation ID (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

const promise = new Promise<AgentMessage>((res) => {
agent.onMessage((msg) => {
if (msg.role === 'assistant') {
res(msg);
}
});
});

await agent[method]({
prompt: 'hi',
correlationId: 'c-1',
});

const msg = await promise;

expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe(undefined);
}
);

test.each(['generateText', 'streamText'] as const)(
'correlation IDs are automatically generated if not provided (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

await agent[method]({
prompt: 'hi',
});

const messages = agent.getMessages();

expect(messages[0]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[0]?.role).toBe('user');
expect(messages[1]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[1]?.role).toBe('assistant');

expect(messages[0]!.correlationId).toEqual(messages[1]!.correlationId);
}
);

test.each(['generateText', 'streamText'] as const)(
'can provide a parent correlation ID (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

await agent[method]({
prompt: 'hi',
correlationId: 'c-1',
parentCorrelationId: 'c-0',
});

const msg = agent.getMessages().find((msg) => msg.role === 'assistant')!;

expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe('c-0');
}
);

test.each(['generateText', 'streamText'] as const)(
'can add feedback to a correlation (%s)',
async (method) => {
const agent = createAgent({
name: 'test',
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

const res = await agent[method]({
prompt: 'test',
});

agent.addFeedback({
correlationId: res.correlationId,
reward: -1,
});

const message = agent.getMessages()[0]!;
const feedback = agent.getFeedback()[0]!;

expect(message.correlationId).toBeDefined();
expect(feedback.correlationId).toEqual(message.correlationId);
}
);
25 changes: 14 additions & 11 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
AgentMemoryContext,
AgentObservation,
ContextFromZodContextMapping,
AgentFeedback,
} from './types';
import { simplePlanner } from './planners/simplePlanner';
import { agentGenerateText, agentStreamText } from './text';
Expand Down Expand Up @@ -73,12 +74,13 @@ export const agentLogic: AgentLogic<AnyEventObject> = fromTransition(
}
return state;
},
{
feedback: [],
messages: [],
observations: [],
plans: [],
} as AgentMemoryContext
() =>
({
feedback: [],
messages: [],
observations: [],
plans: [],
} as AgentMemoryContext)
);

export function createAgent<
Expand Down Expand Up @@ -140,8 +142,6 @@ export function createAgent<
logic?: AgentLogic<TEvents>;
adapter?: AIAdapter;
} & GenerateTextOptions): Agent<TContext, TEvents> {
const messageHistoryListeners: Observer<AgentMessage>[] = [];

const agent = createActor(logic) as unknown as Agent<TContext, TEvents>;
agent.events = events;
agent.model = model;
Expand All @@ -155,7 +155,7 @@ export function createAgent<
agent.memory = getMemory ? getMemory(agent) : undefined;

agent.onMessage = (callback) => {
messageHistoryListeners.push(toObserver(callback));
agent.on('message', (ev) => callback(ev.message));
};

agent.decide = (opts) => {
Expand All @@ -168,7 +168,8 @@ export function createAgent<
id: messageInput.id ?? randomId(),
timestamp: messageInput.timestamp ?? Date.now(),
sessionId: agent.sessionId,
};
correlationId: messageInput.correlationId ?? randomId(),
} satisfies AgentMessage;
agent.send({
type: 'agent.message',
message,
Expand All @@ -185,9 +186,11 @@ export function createAgent<
agent.addFeedback = (feedbackInput) => {
const feedback = {
...feedbackInput,
attributes: { ...feedbackInput.attributes },
reward: feedbackInput.reward ?? 0,
timestamp: feedbackInput.timestamp ?? Date.now(),
sessionId: agent.sessionId,
};
} satisfies AgentFeedback;
agent.send({
type: 'agent.feedback',
feedback,
Expand Down
3 changes: 1 addition & 2 deletions src/planners/simplePlanner.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CoreTool, tool } from 'ai';
import { type CoreTool, tool } from 'ai';
import {
AgentPlan,
AgentPlanInput,
Expand Down Expand Up @@ -118,7 +118,6 @@ export async function simplePlanner<T extends AnyAgent>(
const singleResult = result.toolResults[0];

if (!singleResult) {
console.log(toolMap);
// TODO: retries?
console.warn('No tool call results returned');
return undefined;
Expand Down
36 changes: 26 additions & 10 deletions src/text.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import type {
CoreMessage,
CoreTool,
GenerateTextResult,
StreamTextResult,
} from 'ai';
import type { CoreMessage, CoreTool, GenerateTextResult } from 'ai';
import {
AgentGenerateTextOptions,
AgentGenerateTextResult,
AgentStreamTextOptions,
AgentStreamTextResult,
AnyAgent,
} from './types';
import { defaultTextTemplate } from './templates/defaultText';
Expand Down Expand Up @@ -51,11 +48,13 @@ export async function getMessages(
export async function agentGenerateText<T extends AnyAgent>(
agent: T,
options: AgentGenerateTextOptions
) {
): Promise<AgentGenerateTextResult> {
const resolvedOptions = {
...agent.defaultOptions,
...options,
correlationId: options.correlationId ?? randomId(),
};
// Generate a correlation ID if one is not provided
const template = resolvedOptions.template ?? defaultTextTemplate;
// TODO: check if messages was provided instead
const id = randomId();
Expand All @@ -76,6 +75,8 @@ export async function agentGenerateText<T extends AnyAgent>(
role: 'user',
content: promptWithContext,
timestamp: Date.now(),
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});

const result = await agent.adapter.generateText({
Expand All @@ -91,18 +92,25 @@ export async function agentGenerateText<T extends AnyAgent>(
timestamp: Date.now(),
responseId: id,
result,
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});

return result;
return {
...result,
parentCorrelationId: resolvedOptions.parentCorrelationId,
correlationId: resolvedOptions.correlationId,
};
}

export async function agentStreamText(
agent: AnyAgent,
options: AgentStreamTextOptions
): Promise<StreamTextResult<any>> {
): Promise<AgentStreamTextResult> {
const resolvedOptions = {
...agent.defaultOptions,
...options,
correlationId: options.correlationId ?? randomId(),
};
const template = resolvedOptions.template ?? defaultTextTemplate;

Expand All @@ -124,6 +132,8 @@ export async function agentStreamText(
content: promptWithContext,
id,
timestamp: Date.now(),
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});

const result = await agent.adapter.streamText({
Expand All @@ -149,11 +159,17 @@ export async function agentStreamText(
id: randomId(),
timestamp: Date.now(),
responseId: id,
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});
},
});

return result;
return {
...result,
parentCorrelationId: resolvedOptions.parentCorrelationId,
correlationId: resolvedOptions.correlationId,
} as unknown as AgentStreamTextResult; // TODO: fix
}

export function fromTextStream<T extends AnyAgent>(
Expand Down
Loading

0 comments on commit 8cc48d0

Please sign in to comment.