Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: --prompt --stream for Bedrock and noop engines. #109

Merged
merged 5 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/args.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export function makeArgs(argv = process.argv) {
"template-view": { type: "string", default: process.env["AILLY_TEMPLATE_VIEW"] ? [process.env["AILLY_TEMPLATE_VIEW"]] : [], multiple: true },
prompt: { type: "string", default: process.env["AILLY_PROMPT"], short: "p" },
system: { type: "string", default: process.env["AILLY_SYSTEM"], short: "s" },
stream: { type: 'boolean', default: false },
"request-limit": { type: "string", default: process.env["AILLY_REQUEST_LIMIT"] },
"max-depth": { type: "string", default: "1" },
temperature: { type: "string", default: "" },
Expand Down Expand Up @@ -63,6 +64,7 @@ export function help() {
'none' includes no additional content (including no system context) when generating.
(note: context is separate from isolated. isolated: true with either 'content' or 'folder' will result in the same behavior with either. With 'none', Ailly will send _only_ the prompt when generating.)

--stream (--prompt only) print responses as they return.
-e, --edit use Ailly in edit mode. Provide a single file in paths, an edit marker, and a prompt. The path will be updated with the edit marker at the prompt.
-l, --lines the lines to edit as '[start]:[end]' with start inclusive, and end exclusive. With only '[start]', will insert after. With only ':[end]', will insert before.

Expand Down
29 changes: 23 additions & 6 deletions cli/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,22 @@ async function main() {
default:
LOGGER.info(`Starting ${loaded.content.length} requests`);
generator.start();
await generator.allSettled();

const doneSummary = generator.summary();
LOGGER.info(`All ${doneSummary.totalPrompts} requests finished`);
if (doneSummary.errors) {
LOGGER.warn(`Finished with errors`, { errors: doneSummary.errors });
if (!args.values.stream) {
await finish(generator);
}
if (last == "/dev/stdout") {
const prompt = loaded.context[last];
if (args.values.stream) {
// Lazy spin until the request starts
while (prompt.responseStream == undefined) {
await Promise.resolve();
}
for await (const word of prompt.responseStream) {
process.stdout.write(word);
}
await finish(generator);
}
console.debug(`Finished prompt, final meta`, { meta: prompt.meta });
if (prompt.meta?.debug?.finish == 'failed') {
console.error(prompt.meta.debug.error.message);
return;
Expand All @@ -75,6 +82,16 @@ async function main() {
}
}

/**
 * Waits for every request owned by the generator to settle, then logs a
 * completion summary (and a warning when the summary reports errors).
 *
 * @param generator object exposing `allSettled()` and `summary()`
 */
async function finish(generator) {
  await generator.allSettled();

  const { totalPrompts, errors } = generator.summary();
  LOGGER.info(`All ${totalPrompts} requests finished`);
  if (!errors) {
    return;
  }
  LOGGER.warn(`Finished with errors`, { errors });
}

/**
*
* @param {import("@davidsouther/jiffies/lib/esm/fs").FileSystem} fs
Expand Down
229 changes: 229 additions & 0 deletions core/src/actions/prompt_thread.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { range } from "@davidsouther/jiffies/lib/esm/range.js";
import { PromptThread, generateOne, scheduler } from "./prompt_thread";
import { LOGGER } from "../util";
import { cleanState } from "@davidsouther/jiffies/lib/esm/scope/state";
import { loadContent } from "../content/content";
import {
FileSystem,
ObjectFileSystemAdapter,
} from "@davidsouther/jiffies/lib/esm/fs";
import { getPlugin, makePipelineSettings } from "../ailly";
import { getEngine } from "../engine";
import { LEVEL } from "@davidsouther/jiffies/lib/esm/log";
import { TIMEOUT } from "../engine/noop";

// Exercises `scheduler` with a limit of 2 and checks that only that many
// runners are started until one of them finishes.
describe("scheduler", () => {
it("limits outstanding tasks", async () => {
// Each task carries progress flags plus the promise/resolve/reject members
// spread in from Promise.withResolvers, so the test controls when the task's
// runner is allowed to complete.
const tasks = range(0, 5).map((i) => ({
i,
started: false,
finished: false,
...Promise.withResolvers<void>(),
}));
// A runner flags its task as started, blocks on the task's promise, and then
// flags the task as finished.
const runners = tasks.map((task) => async () => {
console.log(`starting ${task.i}`);
task.started = true;
await task.promise;
console.log(`finishing ${task.i}`);
task.finished = true;
});

// Not awaited: the assertions below inspect state right after the call returns.
scheduler(runners, 2);

// With a limit of 2, only the first two runners are expected to have begun.
expect(tasks[0].started).toBe(true);
expect(tasks[1].started).toBe(true);
expect(tasks[2].started).toBe(false);
expect(tasks[3].started).toBe(false);
expect(tasks[4].started).toBe(false);

// Resolve task 0 from a microtask; its runner is expected to have finished by
// the time this await resumes.
await Promise.resolve().then(() => tasks[0].resolve());
expect(tasks[0].finished).toBe(true);
await Promise.resolve(); // Allow outstanding to clear
await Promise.resolve(); // Allow loop to continue

// Task 2 is now expected to have started; tasks 3 and 4 are still waiting.
expect(tasks[1].started).toBe(true);
expect(tasks[2].started).toBe(true);
expect(tasks[3].started).toBe(false);
expect(tasks[4].started).toBe(false);
});
});

// Tests for `generateOne` against the "noop" engine and an in-memory file
// system: content that should be skipped, and content that gets generated.
describe("generateOne", () => {
// Logger level captured up front so afterEach can put it back.
let level = LOGGER.level;
// Shared fixture built by the async initializer below.
// NOTE(review): presumably cleanState re-runs the initializer through the
// supplied beforeEach hook so every test sees fresh state — confirm against
// jiffies' cleanState.
const state = cleanState(async () => {
// Spy on the logger so tests can assert on the messages it receives.
const logger = {
info: vi.spyOn(LOGGER, "info"),
debug: vi.spyOn(LOGGER, "debug"),
};
LOGGER.level = LEVEL.SILENT;
// Four in-memory files; the assertions below expect loadContent to report
// "Found 3", so presumably a.txt.ailly.md is loaded as a.txt's response rather
// than as its own entry — TODO confirm in loadContent.
const context = await loadContent(
new FileSystem(
new ObjectFileSystemAdapter({
"a.txt": `prompt a`,
"a.txt.ailly.md": `response a`,
"b.txt": `---\nprompt: prompt b\nskip: true\n---\nresponse b`,
"c.txt": "tell me a joke\n",
})
)
);
const engine = await getEngine("noop");
// NOTE(review): presumably removes the noop engine's artificial delay —
// confirm in ../engine/noop.
TIMEOUT.setTimeout(0);
// Sanity-check what loading logged, then clear the spies so each test only
// observes calls made by the test itself.
expect(logger.debug).toHaveBeenCalledWith("Loading content from /");
expect(logger.debug).toHaveBeenCalledWith("Found 3 at or below /");
expect(logger.info).toHaveBeenCalledTimes(0);
logger.debug.mockClear();
logger.info.mockClear();
return { logger, context, engine };
}, beforeEach);

// Undo the spies and restore the logger level and timeout changed during setup.
afterEach(() => {
vi.restoreAllMocks();
LOGGER.level = level;
TIMEOUT.resetTimeout();
});

it("skips some", async () => {
// generateOne is not awaited here; the assertion checks that the "Skipping"
// message has already been logged when the call returns.
// /a.txt is run with overwrite: false.
generateOne(
state.context["/a.txt"],
state.context,
await makePipelineSettings({ root: "/", overwrite: false }),
state.engine
);
expect(state.logger.info).toHaveBeenCalledWith("Skipping /a.txt");
state.logger.info.mockClear();

// /b.txt declares `skip: true` in its front matter.
generateOne(
state.context["/b.txt"],
state.context,
await makePipelineSettings({ root: "/" }),
state.engine
);
expect(state.logger.info).toHaveBeenCalledWith("Skipping /b.txt");
state.logger.info.mockClear();
});

it("generates others", async () => {
// /c.txt starts without a response; after generateOne resolves, the test
// expects the preparation and engine-call log messages and a response with
// the noop prefix.
const content = state.context["/c.txt"];
expect(content.response).toBeUndefined();
await generateOne(
content,
state.context,
await makePipelineSettings({ root: "/" }),
state.engine
);
expect(state.logger.info).toHaveBeenCalledWith("Preparing /c.txt");
expect(state.logger.info).toHaveBeenCalledWith("Calling noop");
expect(content.response).toMatch(/^noop response for c.txt:/);
});
});

// Tests for `PromptThread.run` against the "noop" engine, once with isolated
// settings and once as a sequence.
describe("PromptThread", () => {
// Logger level captured up front so afterEach can put it back.
let level = LOGGER.level;
// Shared fixture built by the async initializer below.
// NOTE(review): presumably cleanState re-runs the initializer through the
// supplied beforeEach hook so every test sees fresh state — confirm against
// jiffies' cleanState.
const state = cleanState(async () => {
// Spy on the logger; each test clears these spies after loading content.
const logger = {
info: vi.spyOn(LOGGER, "info"),
debug: vi.spyOn(LOGGER, "debug"),
};
LOGGER.level = LEVEL.SILENT;
// Only the file system is prepared here; each test calls loadContent itself
// because the two tests pass different loading options.
const fs = new FileSystem(
new ObjectFileSystemAdapter({
"a.txt": `prompt a`,
"a.txt.ailly.md": `response a`,
"b.txt": `---\nprompt: prompt b\nskip: true\n---\nresponse b`,
"c.txt": "tell me a joke\n",
})
);
const engine = await getEngine("noop");
// NOTE(review): presumably removes the noop engine's artificial delay —
// confirm in ../engine/noop.
TIMEOUT.setTimeout(0);
return { logger, fs, engine };
}, beforeEach);

// Undo the spies and restore the logger level and timeout changed during setup.
afterEach(() => {
vi.restoreAllMocks();
LOGGER.level = level;
TIMEOUT.resetTimeout();
});

it("runs isolated", async () => {
// Both the pipeline settings and the content load request isolated mode.
const settings = await makePipelineSettings({ root: "/", isolated: true });
const context = await loadContent(state.fs, [], { isolated: true });
state.logger.debug.mockClear();
state.logger.info.mockClear();
const content = [...Object.values(context)];
// Resolve the "none" plugin module and build a plugin from its default export.
const plugin = await (
await getPlugin("none")
).default(state.engine, settings);
const thread = PromptThread.run(
content,
context,
settings,
state.engine,
plugin
);
// Immediately after run(): nothing finished and no errors recorded.
expect(thread.isDone).toBe(false);
expect(thread.finished).toBe(0);
expect(thread.errors.length).toBe(0);

// Yield to the microtask queue a fixed number of times; the count is tuned to
// the implementation, so changing it alters what the next assertions observe.
await Promise.resolve();
await Promise.resolve();
await Promise.resolve();
await Promise.resolve();
await Promise.resolve();
await Promise.resolve();
// Enough to get one resolved

expect(thread.isDone).toBe(false);
expect(thread.finished).toBe(1);
expect(thread.errors.length).toBe(0);

await thread.allSettled();

// After settling, all three items are counted as finished.
expect(thread.isDone).toBe(true);
expect(thread.finished).toBe(3);
expect(thread.errors.length).toBe(0);

// In this mode the expected responses mention only each file's own messages.
expect(content[0].response).toEqual(
`noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a`
);
// NOTE(review): presumably content[1] is b.txt (skip: true) — confirm the
// ordering of Object.values(context).
expect(content[1].response).toBeUndefined();
expect(content[2].response).toEqual(
`noop response for c.txt:\n\nsystem: \nuser: tell me a joke\n\ntell me a joke\n`
);
});

it("runs sequence", async () => {
// Default settings and default content loading (no isolated option).
const settings = await makePipelineSettings({ root: "/" });
const context = await loadContent(state.fs);
state.logger.debug.mockClear();
state.logger.info.mockClear();
const content = [...Object.values(context)];
// Resolve the "none" plugin module and build a plugin from its default export.
const plugin = await (
await getPlugin("none")
).default(state.engine, settings);
const thread = PromptThread.run(
content,
context,
settings,
state.engine,
plugin
);
// Immediately after run(): nothing finished and no errors recorded.
expect(thread.isDone).toBe(false);
expect(thread.finished).toBe(0);
expect(thread.errors.length).toBe(0);

await thread.allSettled();

expect(thread.isDone).toBe(true);
expect(thread.finished).toBe(3);
expect(thread.errors.length).toBe(0);

expect(content[0].response).toEqual(
`noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a`
);
expect(content[1].response).toBeUndefined();
// Unlike the isolated run, the last item's expected response also contains the
// messages of the items before it.
expect(content[2].response).toEqual(
`noop response for c.txt:\n\nsystem: \nuser: prompt a\nassistant: noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a\nuser: response b\nuser: tell me a joke\n\ntell me a joke\n`
);
});
});
35 changes: 25 additions & 10 deletions core/src/actions/prompt_thread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export async function scheduler<T>(
let finished: Array<Promise<T>> = [];
let outstanding = new Set<Promise<T>>();
while (taskQueue.length > 0) {
if (outstanding.size > limit) {
if (outstanding.size >= limit) {
// Wait for something in outstanding to finish
await Promise.race([...outstanding]);
} else {
Expand All @@ -47,7 +47,7 @@ export async function scheduler<T>(
export class PromptThread {
finished: number = 0;
isolated: boolean = false;
done: boolean = false;
private done: boolean = false;
runner?: Promise<PromiseSettledResult<Content>[]>;
// Results holds a list of errors that occurred and the index they occurred at.
// If the thread is isolated, this can have many entries. If the thread is not
Expand Down Expand Up @@ -87,7 +87,7 @@ export class PromptThread {
private plugin: Plugin
) {
this.content = content;
this.isolated = Boolean(content[0]?.meta?.isolated ?? false);
this.isolated = Boolean(settings.isolated ?? false);
}

start() {
Expand Down Expand Up @@ -168,17 +168,20 @@ export class PromptThread {
}
}

async function generateOne(
export function generateOne(
c: Content,
context: Record<string, Content>,
settings: PipelineSettings,
engine: Engine
): Promise<Content> {
): Promise<void> {
const has_response = (c.response?.length ?? 0) > 0;

if (c.meta?.skip || (!settings.overwrite && has_response)) {
LOGGER.info(`Skipping ${c.path}`);
return c;
const stream = new TextEncoderStream();
stream.writable.getWriter().write(c.response ?? "");
c.responseStream = stream.readable;
return Promise.resolve();
}

LOGGER.info(`Preparing ${c.path}`);
Expand All @@ -198,15 +201,27 @@ async function generateOne(
// Skip the last `assistant` message
.filter((m, i, a) => !(m.role == "assistant" && i === a.length - 1)),
});
const generated = await engine.generate(c, settings);
c.response = generated.message;
c.meta = {
...c.meta,
debug: {
...generated.debug,
engine: settings.engine,
model: settings.model,
},
};
return c;
try {
const generator = engine.generate(c, settings);
c.responseStream = generator.stream;
return generator.done.then(
() => {
c.response = generator.message();
c.meta!.debug = { ...c.meta!.debug, ...generator.debug() };
},
(err) => {
c.meta!.debug = { ...c.meta!.debug, ...generator.debug() };
}
);
} catch (err) {
LOGGER.error(`Uncaught error in ${engine.name} generator`, { err });
return Promise.resolve();
}
}
Loading