Skip to content

Commit

Permalink
Add changesets
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkpiano committed Jul 21, 2024
1 parent da9b0a7 commit 8b7c374
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 278 deletions.
26 changes: 26 additions & 0 deletions .changeset/khaki-emus-jog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
'@statelyai/agent': minor
---

Correlation IDs are now provided as part of the result from `agent.generateText(…)` and `agent.streamText(…)`:

```ts
const result = await agent.generateText({
prompt: 'Write me a song',
correlationId: 'my-correlation-id',
// ...
});

result.correlationId; // 'my-correlation-id'
```

These correlation IDs can be passed to feedback:

```ts
// ...

agent.addFeedback({
reward: -1,
correlationId: result.correlationId,
});
```
11 changes: 11 additions & 0 deletions .changeset/silent-deers-visit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
'@statelyai/agent': minor
---

Changes to agent feedback (the `AgentFeedback` interface):

- `goal` is now optional
- `observationId` is now optional
- `correlationId` has been added (optional)
- `reward` has been added (optional)
- `attributes` are now optional
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"ts-node": "^10.9.2",
"tsup": "^8.2.0",
"typescript": "^5.5.3",
"vitest": "^1.6.0",
"vitest": "^2.0.3",
"wikipedia": "^2.1.2",
"zod": "^3.23.8"
},
Expand Down
235 changes: 76 additions & 159 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

224 changes: 120 additions & 104 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -363,128 +363,144 @@ test('agent.types provides context and event types', () => {
agent.types.context satisfies { score: string };
});

test('can provide a correlation ID', async () => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
generateText: async () => {
const res = {
text: 'response',
};
test.each(['generateText', 'streamText'] as const)(
'can provide a correlation ID (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

return res as AgentGenerateTextResult;
},
streamText: {} as any,
},
});
const promise = new Promise<AgentMessage>((res) => {
agent.onMessage((msg) => {
if (msg.role === 'assistant') {
res(msg);
}
});
});

const promise = new Promise<AgentMessage>((res) => {
agent.onMessage((msg) => {
if (msg.role === 'assistant') {
res(msg);
}
await agent[method]({
prompt: 'hi',
correlationId: 'c-1',
});
});

await agent.generateText({
prompt: 'hi',
correlationId: 'c-1',
});
const msg = await promise;

const msg = await promise;
expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe(undefined);
}
);

expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe(undefined);
});
test.each(['generateText', 'streamText'] as const)(
'correlation IDs are automatically generated if not provided (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

test('correlation IDs are automatically generated if not provided', async () => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
generateText: async () => {
const res = {
text: 'response',
};
opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
streamText: {} as any,
},
});
return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

await agent.generateText({
prompt: 'hi',
});
await agent[method]({
prompt: 'hi',
});

const messages = agent.getMessages();
const messages = agent.getMessages();

expect(messages[0]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[0]?.role).toBe('user');
expect(messages[1]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[1]?.role).toBe('assistant');
expect(messages[0]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[0]?.role).toBe('user');
expect(messages[1]?.correlationId).toEqual(expect.stringMatching(/.+/));
expect(messages[1]?.role).toBe('assistant');

expect(messages[0]!.correlationId).toEqual(messages[1]!.correlationId);
});
expect(messages[0]!.correlationId).toEqual(messages[1]!.correlationId);
}
);

test('can provide a parent correlation ID', async () => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
generateText: async (o) => {
const res = {
text: 'response',
};

return res as AgentGenerateTextResult;
},
streamText: {} as any,
},
});

await agent.generateText({
prompt: 'hi',
correlationId: 'c-1',
parentCorrelationId: 'c-0',
});
test.each(['generateText', 'streamText'] as const)(
'can provide a parent correlation ID (%s)',
async (method) => {
const agent = createAgent({
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

const msg = agent.getMessages().find((msg) => msg.role === 'assistant')!;
opts.onFinish?.(res);

expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe('c-0');
});
return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

test('can add feedback to a correlation', async () => {
const agent = createAgent({
name: 'test',
model: {} as any,
events: {},
adapter: {
generateText: async () => {
const res = {
text: 'response',
};
await agent[method]({
prompt: 'hi',
correlationId: 'c-1',
parentCorrelationId: 'c-0',
});

return res as AgentGenerateTextResult;
},
streamText: {} as any,
},
});
const msg = agent.getMessages().find((msg) => msg.role === 'assistant')!;

expect(msg.correlationId).toBe('c-1');
expect(msg.parentCorrelationId).toBe('c-0');
}
);

test.each(['generateText', 'streamText'] as const)(
'can add feedback to a correlation (%s)',
async (method) => {
const agent = createAgent({
name: 'test',
model: {} as any,
events: {},
adapter: {
[method]: async (opts) => {
const res = {
text: 'response',
};

opts.onFinish?.(res);

return res as AgentGenerateTextResult;
},
} as any as AIAdapter,
});

const res = await agent.generateText({
prompt: 'test',
});
const res = await agent[method]({
prompt: 'test',
});

agent.addFeedback({
correlationId: res.correlationId,
reward: -1,
});
agent.addFeedback({
correlationId: res.correlationId,
reward: -1,
});

const message = agent.getMessages()[0]!;
const feedback = agent.getFeedback()[0]!;
const message = agent.getMessages()[0]!;
const feedback = agent.getFeedback()[0]!;

expect(message.correlationId).toBeDefined();
expect(feedback.correlationId).toEqual(message.correlationId);
});
expect(message.correlationId).toBeDefined();
expect(feedback.correlationId).toEqual(message.correlationId);
}
);
15 changes: 7 additions & 8 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ export const agentLogic: AgentLogic<AnyEventObject> = fromTransition(
}
return state;
},
{
feedback: [],
messages: [],
observations: [],
plans: [],
} as AgentMemoryContext
() =>
({
feedback: [],
messages: [],
observations: [],
plans: [],
} as AgentMemoryContext)
);

export function createAgent<
Expand Down Expand Up @@ -141,8 +142,6 @@ export function createAgent<
logic?: AgentLogic<TEvents>;
adapter?: AIAdapter;
} & GenerateTextOptions): Agent<TContext, TEvents> {
const messageHistoryListeners: Observer<AgentMessage>[] = [];

const agent = createActor(logic) as unknown as Agent<TContext, TEvents>;
agent.events = events;
agent.model = model;
Expand Down
11 changes: 5 additions & 6 deletions src/text.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import type {
CoreMessage,
CoreTool,
GenerateTextResult,
StreamTextResult,
} from 'ai';
import type { CoreMessage, CoreTool, GenerateTextResult } from 'ai';
import {
AgentGenerateTextOptions,
AgentGenerateTextResult,
Expand Down Expand Up @@ -137,6 +132,8 @@ export async function agentStreamText(
content: promptWithContext,
id,
timestamp: Date.now(),
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});

const result = await agent.adapter.streamText({
Expand All @@ -162,6 +159,8 @@ export async function agentStreamText(
id: randomId(),
timestamp: Date.now(),
responseId: id,
correlationId: resolvedOptions.correlationId,
parentCorrelationId: resolvedOptions.parentCorrelationId,
});
},
});
Expand Down

0 comments on commit 8b7c374

Please sign in to comment.