diff --git a/cli/args.js b/cli/args.js index 2e57171..05549a5 100644 --- a/cli/args.js +++ b/cli/args.js @@ -20,6 +20,7 @@ export function makeArgs(argv = process.argv) { "template-view": { type: "string", default: process.env["AILLY_TEMPLATE_VIEW"] ? [process.env["AILLY_TEMPLATE_VIEW"]] : [], multiple: true }, prompt: { type: "string", default: process.env["AILLY_PROMPT"], short: "p" }, system: { type: "string", default: process.env["AILLY_SYSTEM"], short: "s" }, + stream: { type: 'boolean', default: false }, "request-limit": { type: "string", default: process.env["AILLY_REQUEST_LIMIT"] }, "max-depth": { type: "string", default: "1" }, temperature: { type: "string", default: "" }, @@ -63,6 +64,7 @@ export function help() { 'none' includes no additional content (including no system context) when generating. (note: context is separate from isolated. isolated: true with either 'content' or 'folder' will result in the same behavior with either. With 'none', Ailly will send _only_ the prompt when generating.) + --stream (--prompt only) print responses as they return. -e, --edit use Ailly in edit mode. Provide a single file in paths, an edit marker, and a prompt. The path will be updated with the edit marker at the prompt. -l, --lines the lines to edit as '[start]:[end]' with start inclusive, and end exclusive. With only '[start]', will insert after. With only ':[end]', will insert before. diff --git a/cli/index.js b/cli/index.js index d82fe21..382261c 100755 --- a/cli/index.js +++ b/cli/index.js @@ -49,15 +49,22 @@ async function main() { default: LOGGER.info(`Starting ${loaded.content.length} requests`); generator.start(); - await generator.allSettled(); - - const doneSummary = generator.summary(); - LOGGER.info(`All ${doneSummary.totalPrompts} requests finished`); - if (doneSummary.errors) { - LOGGER.warn(`Finished with errors`, { errors: doneSummary.errors }); + if (!args.values.stream) { + await finish(generator); } if (last == "/dev/stdout") { const prompt = loaded.context[last]; + if (args.values.stream) { + // Lazy spin until the request starts + while (prompt.responseStream == undefined) { + await Promise.resolve(); + } + for await (const word of prompt.responseStream) { + process.stdout.write(word); + } + await finish(generator); + } + console.debug(`Finished prompt, final meta`, { meta: prompt.meta }); if (prompt.meta?.debug?.finish == 'failed') { console.error(prompt.meta.debug.error.message); return; @@ -75,6 +82,16 @@ async function main() { } } +async function finish(generator) { + await generator.allSettled(); + + const doneSummary = generator.summary(); + LOGGER.info(`All ${doneSummary.totalPrompts} requests finished`); + if (doneSummary.errors) { + LOGGER.warn(`Finished with errors`, { errors: doneSummary.errors }); + } +} + /** * * @param {import("@davidsouther/jiffies/lib/esm/fs").FileSystem} fs diff --git a/core/src/actions/prompt_thread.test.ts b/core/src/actions/prompt_thread.test.ts new file mode 100644 index 0000000..537879c --- /dev/null +++ b/core/src/actions/prompt_thread.test.ts @@ -0,0 +1,229 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { range } from "@davidsouther/jiffies/lib/esm/range.js"; +import { PromptThread, generateOne, scheduler } from "./prompt_thread"; +import { LOGGER } from "../util"; +import { cleanState } from "@davidsouther/jiffies/lib/esm/scope/state"; +import { loadContent } from "../content/content"; +import { + FileSystem, + ObjectFileSystemAdapter, +} from "@davidsouther/jiffies/lib/esm/fs"; +import { 
getPlugin, makePipelineSettings } from "../ailly"; +import { getEngine } from "../engine"; +import { LEVEL } from "@davidsouther/jiffies/lib/esm/log"; +import { TIMEOUT } from "../engine/noop"; + +describe("scheduler", () => { + it("limits outstanding tasks", async () => { + const tasks = range(0, 5).map((i) => ({ + i, + started: false, + finished: false, + ...Promise.withResolvers(), + })); + const runners = tasks.map((task) => async () => { + console.log(`starting ${task.i}`); + task.started = true; + await task.promise; + console.log(`finishing ${task.i}`); + task.finished = true; + }); + + scheduler(runners, 2); + + expect(tasks[0].started).toBe(true); + expect(tasks[1].started).toBe(true); + expect(tasks[2].started).toBe(false); + expect(tasks[3].started).toBe(false); + expect(tasks[4].started).toBe(false); + + await Promise.resolve().then(() => tasks[0].resolve()); + expect(tasks[0].finished).toBe(true); + await Promise.resolve(); // Allow outstanding to clear + await Promise.resolve(); // Allow loop to continue + + expect(tasks[1].started).toBe(true); + expect(tasks[2].started).toBe(true); + expect(tasks[3].started).toBe(false); + expect(tasks[4].started).toBe(false); + }); +}); + +describe("generateOne", () => { + let level = LOGGER.level; + const state = cleanState(async () => { + const logger = { + info: vi.spyOn(LOGGER, "info"), + debug: vi.spyOn(LOGGER, "debug"), + }; + LOGGER.level = LEVEL.SILENT; + const context = await loadContent( + new FileSystem( + new ObjectFileSystemAdapter({ + "a.txt": `prompt a`, + "a.txt.ailly.md": `response a`, + "b.txt": `---\nprompt: prompt b\nskip: true\n---\nresponse b`, + "c.txt": "tell me a joke\n", + }) + ) + ); + const engine = await getEngine("noop"); + TIMEOUT.setTimeout(0); + expect(logger.debug).toHaveBeenCalledWith("Loading content from /"); + expect(logger.debug).toHaveBeenCalledWith("Found 3 at or below /"); + expect(logger.info).toHaveBeenCalledTimes(0); + logger.debug.mockClear(); + logger.info.mockClear(); + return { logger, context, engine }; + }, beforeEach); + + afterEach(() => { + vi.restoreAllMocks(); + LOGGER.level = level; + TIMEOUT.resetTimeout(); + }); + + it("skips some", async () => { + generateOne( + state.context["/a.txt"], + state.context, + await makePipelineSettings({ root: "/", overwrite: false }), + state.engine + ); + expect(state.logger.info).toHaveBeenCalledWith("Skipping /a.txt"); + state.logger.info.mockClear(); + + generateOne( + state.context["/b.txt"], + state.context, + await makePipelineSettings({ root: "/" }), + state.engine + ); + expect(state.logger.info).toHaveBeenCalledWith("Skipping /b.txt"); + state.logger.info.mockClear(); + }); + + it("generates others", async () => { + const content = state.context["/c.txt"]; + expect(content.response).toBeUndefined(); + await generateOne( + content, + state.context, + await makePipelineSettings({ root: "/" }), + state.engine + ); + expect(state.logger.info).toHaveBeenCalledWith("Preparing /c.txt"); + expect(state.logger.info).toHaveBeenCalledWith("Calling noop"); + expect(content.response).toMatch(/^noop response for c.txt:/); + }); +}); + +describe("PromptThread", () => { + let level = LOGGER.level; + const state = cleanState(async () => { + const logger = { + info: vi.spyOn(LOGGER, "info"), + debug: vi.spyOn(LOGGER, "debug"), + }; + LOGGER.level = LEVEL.SILENT; + const fs = new FileSystem( + new ObjectFileSystemAdapter({ + "a.txt": `prompt a`, + "a.txt.ailly.md": `response a`, + "b.txt": `---\nprompt: prompt b\nskip: true\n---\nresponse b`, + "c.txt": "tell 
me a joke\n", + }) + ); + const engine = await getEngine("noop"); + TIMEOUT.setTimeout(0); + return { logger, fs, engine }; + }, beforeEach); + + afterEach(() => { + vi.restoreAllMocks(); + LOGGER.level = level; + TIMEOUT.resetTimeout(); + }); + + it("runs isolated", async () => { + const settings = await makePipelineSettings({ root: "/", isolated: true }); + const context = await loadContent(state.fs, [], { isolated: true }); + state.logger.debug.mockClear(); + state.logger.info.mockClear(); + const content = [...Object.values(context)]; + const plugin = await ( + await getPlugin("none") + ).default(state.engine, settings); + const thread = PromptThread.run( + content, + context, + settings, + state.engine, + plugin + ); + expect(thread.isDone).toBe(false); + expect(thread.finished).toBe(0); + expect(thread.errors.length).toBe(0); + + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + // Enough to get one resolved + + expect(thread.isDone).toBe(false); + expect(thread.finished).toBe(1); + expect(thread.errors.length).toBe(0); + + await thread.allSettled(); + + expect(thread.isDone).toBe(true); + expect(thread.finished).toBe(3); + expect(thread.errors.length).toBe(0); + + expect(content[0].response).toEqual( + `noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a` + ); + expect(content[1].response).toBeUndefined(); + expect(content[2].response).toEqual( + `noop response for c.txt:\n\nsystem: \nuser: tell me a joke\n\ntell me a joke\n` + ); + }); + + it("runs sequence", async () => { + const settings = await makePipelineSettings({ root: "/" }); + const context = await loadContent(state.fs); + state.logger.debug.mockClear(); + state.logger.info.mockClear(); + const content = [...Object.values(context)]; + const plugin = await ( + await getPlugin("none") + ).default(state.engine, settings); + const thread = PromptThread.run( + content, + context, + settings, + state.engine, + plugin + ); + expect(thread.isDone).toBe(false); + expect(thread.finished).toBe(0); + expect(thread.errors.length).toBe(0); + + await thread.allSettled(); + + expect(thread.isDone).toBe(true); + expect(thread.finished).toBe(3); + expect(thread.errors.length).toBe(0); + + expect(content[0].response).toEqual( + `noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a` + ); + expect(content[1].response).toBeUndefined(); + expect(content[2].response).toEqual( + `noop response for c.txt:\n\nsystem: \nuser: prompt a\nassistant: noop response for a.txt:\n\nsystem: \nuser: prompt a\nassistant: response a\nprompt a\nuser: response b\nuser: tell me a joke\n\ntell me a joke\n` + ); + }); +}); diff --git a/core/src/actions/prompt_thread.ts b/core/src/actions/prompt_thread.ts index 36521fb..d846e0f 100644 --- a/core/src/actions/prompt_thread.ts +++ b/core/src/actions/prompt_thread.ts @@ -28,7 +28,7 @@ export async function scheduler( let finished: Array> = []; let outstanding = new Set>(); while (taskQueue.length > 0) { - if (outstanding.size > limit) { + if (outstanding.size >= limit) { // Wait for something in outstanding to finish await Promise.race([...outstanding]); } else { @@ -47,7 +47,7 @@ export async function scheduler( export class PromptThread { finished: number = 0; isolated: boolean = false; - done: boolean = false; + private done: boolean = false; runner?: Promise[]>; // Results holds a list of errors that occurred and the index the occurred at. 
// If the thread is isolated, this can have many entries. If the thread is not @@ -87,7 +87,7 @@ export class PromptThread { private plugin: Plugin ) { this.content = content; - this.isolated = Boolean(content[0]?.meta?.isolated ?? false); + this.isolated = Boolean(settings.isolated ?? false); } start() { @@ -168,17 +168,20 @@ export class PromptThread { } } -async function generateOne( +export function generateOne( c: Content, context: Record, settings: PipelineSettings, engine: Engine -): Promise { +): Promise { const has_response = (c.response?.length ?? 0) > 0; if (c.meta?.skip || (!settings.overwrite && has_response)) { LOGGER.info(`Skipping ${c.path}`); - return c; + const stream = new TextEncoderStream(); + stream.writable.getWriter().write(c.response ?? ""); + c.responseStream = stream.readable; + return Promise.resolve(); } LOGGER.info(`Preparing ${c.path}`); @@ -198,15 +201,27 @@ async function generateOne( // Skip the last `assistant` message .filter((m, i, a) => !(m.role == "assistant" && i === a.length - 1)), }); - const generated = await engine.generate(c, settings); - c.response = generated.message; c.meta = { ...c.meta, debug: { - ...generated.debug, engine: settings.engine, model: settings.model, }, }; - return c; + try { + const generator = engine.generate(c, settings); + c.responseStream = generator.stream; + return generator.done.then( + () => { + c.response = generator.message(); + c.meta!.debug = { ...c.meta!.debug, ...generator.debug() }; + }, + (err) => { + c.meta!.debug = { ...c.meta!.debug, ...generator.debug() }; + } + ); + } catch (err) { + LOGGER.error(`Uncaught error in ${engine.name} generator`, { err }); + return Promise.resolve(); + } } diff --git a/core/src/content/content.test.ts b/core/src/content/content.test.ts index 9be92a9..26b8ee3 100644 --- a/core/src/content/content.test.ts +++ b/core/src/content/content.test.ts @@ -39,7 +39,7 @@ test("it loads content", async () => { path: "/01_start.md", outPath: "/01_start.md.ailly.md", prompt: "The quick brown", - response: "", + response: undefined, context: { system: [], view: {}, @@ -56,7 +56,7 @@ test("it loads content", async () => { path: "/20b/40_part.md", outPath: "/20b/40_part.md.ailly.md", prompt: "fox jumped", - response: "", + response: undefined, context: { system: [], view: {}, @@ -73,7 +73,7 @@ test("it loads content", async () => { path: "/20b/56_part.md", outPath: "/20b/56_part.md.ailly.md", prompt: "over the lazy", - response: "", + response: undefined, context: { system: [], view: {}, @@ -91,7 +91,7 @@ test("it loads content", async () => { path: "/54_a/12_section.md", outPath: "/54_a/12_section.md.ailly.md", prompt: "dog.", - response: "", + response: undefined, context: { system: [], view: {}, @@ -156,7 +156,7 @@ test("it loads combined prompt and responses", async () => { path: "/prompt.md", outPath: "/prompt.md", prompt: "prompt", - response: "", + response: undefined, context: { system: [], view: {}, @@ -233,7 +233,7 @@ test("it loads separate prompt and responses", async () => { path: "/prompt.md", outPath: "/prompt.md.ailly.md", prompt: "prompt", - response: "", + response: undefined, context: { system: [], view: {}, @@ -292,7 +292,7 @@ test("it loads separate prompt and responses in different out directors", async path: "/root/prompt.md", outPath: "/out/prompt.md.ailly.md", prompt: "prompt", - response: "", + response: undefined, context: { system: [], view: {}, @@ -407,7 +407,7 @@ test("it writes deep java prompts and responses", async () => { path: 
"/root/src/com/example/Main.java", outPath: "/out/src/com/example/Main.java.ailly.md", prompt: "class Main {}\n", - response: "", + response: undefined, context: { system: [], view: {}, @@ -668,7 +668,7 @@ describe("Load aillyrc", () => { path: "/root/b", outPath: "/root/b.ailly.md", prompt: "b", - response: "", + response: undefined, context: { view: {}, }, diff --git a/core/src/content/content.ts b/core/src/content/content.ts index f0f885b..de77037 100644 --- a/core/src/content/content.ts +++ b/core/src/content/content.ts @@ -30,6 +30,7 @@ export interface Content { // The prompt itself prompt: string; response?: string; + responseStream?: ReadableStream; context: Context; meta?: ContentMeta; } @@ -57,7 +58,7 @@ export interface ContentMeta { skip?: boolean; isolated?: boolean; combined?: boolean; - debug?: unknown; + debug?: {}; view?: false | View; prompt?: string; temperature?: number; @@ -93,7 +94,6 @@ function partitionDirectory(stats: Stats[]): PartitionedDirectory { async function loadDir(fs: FileSystem): Promise { const dir = await fs.readdir("."); const entries = await Promise.all(dir.map((s) => fs.stat(s))); - // const entries = await fs.scandir(""); return partitionDirectory(entries); } @@ -143,7 +143,7 @@ async function loadFile( ); } - let response = ""; + let response: string | undefined; let outPath: string; if (data.prompt) { outPath = promptPath; @@ -171,6 +171,7 @@ async function loadFile( } } } + if (response?.trim() == "") response = undefined; const view = data.view === false ? false : data.view ?? {}; delete data.view; diff --git a/core/src/engine/bedrock/bedrock.ts b/core/src/engine/bedrock/bedrock.ts index 8659824..11c5797 100644 --- a/core/src/engine/bedrock/bedrock.ts +++ b/core/src/engine/bedrock/bedrock.ts @@ -1,10 +1,11 @@ import { BedrockRuntimeClient, InvokeModelCommand, + InvokeModelWithResponseStreamCommand, } from "@aws-sdk/client-bedrock-runtime"; import { Content, View } from "../../content/content.js"; import { LOGGER as ROOT_LOGGER } from "../../util.js"; -import { Summary } from "../index.js"; +import { EngineGenerate, Summary } from "../index.js"; import { Models, PromptBuilder } from "./prompt-builder.js"; import { getLogger } from "@davidsouther/jiffies/lib/esm/log.js"; import { fromNodeProviderChain } from "@aws-sdk/credential-providers"; @@ -15,16 +16,28 @@ export const DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"; const LOGGER = getLogger("@ailly/core:bedrock"); +export interface BedrockDebug { + statistics?: { + inputTokenCount?: number; + outputTokenCount?: number; + invocationLatency?: number; + firstByteLatency?: number; + }; + finish?: string; + error?: Error; + id: string; +} + const MODEL_MAP: Record = { sonnet: "anthropic.claude-3-sonnet-20240229-v1:0", haiku: "anthropic.claude-3-haiku-20240307-v1:0", opus: "anthropic.claude-3-opus-20240229-v1:0", }; -export async function generate( +export const generate: EngineGenerate = ( c: Content, { model = DEFAULT_MODEL }: { model?: string } -): Promise<{ message: string; debug: unknown }> { +) => { LOGGER.level = ROOT_LOGGER.level; LOGGER.format = ROOT_LOGGER.format; const bedrock = new BedrockRuntimeClient({ @@ -59,41 +72,84 @@ export async function generate( }); try { - const response = await bedrock.send( - new InvokeModelCommand({ - modelId: model, - contentType: "application/json", - accept: "application/json", - body: JSON.stringify(prompt), + let message = ""; + const debug: BedrockDebug = { id: "", finish: "unknown" }; + const stream = new TransformStream(); + const writer 
= stream.writable.getWriter(); + const done = bedrock + .send( + new InvokeModelWithResponseStreamCommand({ + modelId: model, + contentType: "application/json", + accept: "application/json", + body: JSON.stringify(prompt), + }) + ) + .then(async (response) => { + LOGGER.info(`Begin streaming response from Bedrock for ${c.name}`); + + for await (const block of response.body ?? []) { + const chunk = JSON.parse( + new TextDecoder().decode(block.chunk?.bytes) + ); + LOGGER.debug( + `Received chunk for (${ + chunk.message?.id ?? debug.id + }) from Bedrock for ${c.name}`, + { chunk } + ); + switch (chunk.type) { + case "message_start": + debug.id = chunk.message.id; + break; + case "content_block_start": + break; + case "content_block_delta": + const text = chunk.delta.text; + await writer.ready; + message += text; + await writer.write(text); + break; + case "message_delta": + debug.finish = chunk.delta.stop_reason; + break; + case "message_stop": + debug.statistics = chunk["amazon-bedrock-invocationMetrics"]; + break; + } + } }) - ); - - const body = JSON.parse(response.body.transformToString()); - response.body = body; - - LOGGER.info(`Response from Bedrock for ${c.name}`); - LOGGER.debug(`Bedrock response`, body); - - let message: string = (body.content?.[0]?.text ?? "").trim(); - // In edit mode, claude (at least) does not return the stop sequence nor the prefill, so the edit is the message. + .catch((e) => { + debug.error = e as Error; + LOGGER.error(`Error for bedrock response ${debug.id}`, { + error: debug.error, + }); + }) + .finally(async () => { + LOGGER.debug(`Closing write stream for bedrock response ${debug.id}`); + await writer.close(); + }); return { - message, - debug: { - id: null, - model, - usage: null, - finish: body.stop_reason, - }, + stream: stream.readable, + message: () => (debug.error ? 
"💩" : message), + debug: () => debug, + done, }; } catch (error) { LOGGER.warn(`Error from Bedrock for ${c.name}`, { error }); return { - message: "💩", - debug: { finish: "failed", error: { message: (error as Error).message } }, + stream: new TextDecoderStream("💩").readable, + message: () => "💩", + debug: () => ({ + finish: "failed", + error: error as Error, + id: "_failed_", + }), + done: Promise.resolve(), }; } -} +}; export async function view(): Promise { return { diff --git a/core/src/engine/index.ts b/core/src/engine/index.ts index b19620f..b617ba4 100644 --- a/core/src/engine/index.ts +++ b/core/src/engine/index.ts @@ -5,14 +5,22 @@ import * as mistral from "./mistral/mistral.js"; import * as noop from "./noop.js"; import { PipelineSettings } from "../ailly.js"; +export type EngineGenerate = + ( + c: Content, + parameters: PipelineSettings + ) => { + stream: ReadableStream; + message(): string; + debug(): D; + done: Promise; + }; + export interface Engine { DEFAULT_MODEL: string; name: string; format(c: Content[], context: Record): Promise; - generate( - c: Content, - parameters: PipelineSettings - ): Promise<{ debug: D; message: string }>; + generate: EngineGenerate; vector(s: string, parameters: ContentMeta): Promise; view?(): Promise; models?(): string[]; diff --git a/core/src/engine/mistral/mistral.ts b/core/src/engine/mistral/mistral.ts index 65ad57b..3cdc753 100644 --- a/core/src/engine/mistral/mistral.ts +++ b/core/src/engine/mistral/mistral.ts @@ -1,45 +1,63 @@ +import { EngineGenerate } from ".."; import type { Content } from "../../content/content.js"; import * as openai from "../openai.js"; import { spawn } from "node:child_process"; import { normalize, join, dirname } from "node:path"; -const MODEL = "mistralai/Mistral-7B-v0.1"; - -export async function generate( - c: Content, - {}: {} -): Promise<{ message: string; debug: unknown }> { - return new Promise<{ message: string; debug: unknown }>((resolve, reject) => { - const prompt = c.meta?.messages?.map(({ content }) => content).join("\n"); - if (!prompt) { - return reject("No messages in Content"); - } - - let cwd = dirname( - (import.meta?.url.replace(/^file:/, "") ?? __filename).replace( - "ailly/core/dist", - "ailly/core/src" - ) - ); - let command = join(cwd, normalize(".venv/bin/python3")); - let args = [join(cwd, "mistral.py"), prompt]; - let child = spawn(command, args, { cwd }); - - let response = ""; - child.on("message", (m) => (response += `${m}`)); - - const done = () => { - resolve({ message: response, debug: {} }); - }; - child.on("exit", done); - child.on("close", done); - child.on("disconnect", done); - - const error = (cause: unknown) => - reject(new Error("child_process had a problem" /*, { cause }*/)); - child.on("error", error); +const DEFAULT_MODEL = "mistralai/Mistral-7B-v0.1"; +interface MistralDebug {} + +export const generate: EngineGenerate = (c: Content, _) => { + const prompt = c.meta?.messages?.map(({ content }) => content).join("\n"); + if (!prompt) { + throw new Error("No messages in Content"); + } + + let cwd = dirname( + (import.meta?.url.replace(/^file:/, "") ?? 
__filename).replace( + "ailly/core/dist", + "ailly/core/src" + ) + ); + let command = join(cwd, normalize(".venv/bin/python3")); + let args = [join(cwd, "mistral.py"), prompt]; + let child = spawn(command, args, { cwd }); + + const stream = new TransformStream(); + const done = Promise.withResolvers(); + + let message = ""; + child.on("message", async (m) => { + const writer = await stream.writable.getWriter(); + await writer.ready; + await writer.write(m); + writer.releaseLock(); + message += `${m}`; }); -} + + const onDone = () => { + stream.writable.close(); + done.resolve(); + }; + child.on("exit", onDone); + child.on("close", onDone); + child.on("disconnect", onDone); + + const onError = (cause: unknown) => { + stream.writable.abort( + `child_process had a problem ${JSON.stringify(cause)}` + ); + done.reject(cause); + }; + child.on("error", onError); + + return { + stream: stream.readable, + message: () => message, + debug: () => ({}), + done: done.promise, + }; +}; export const format = openai.format; export const getMessages = openai.getMessages; @@ -48,7 +66,7 @@ export async function tune( content: Content[], context: Record, { - model = MODEL, + model = DEFAULT_MODEL, apiKey = process.env["OPENAI_API_KEY"] ?? "", baseURL = "http://localhost:8000/v1", }: { model: string; apiKey: string; baseURL: string } diff --git a/core/src/engine/noop.ts b/core/src/engine/noop.ts index 6fc4aaf..8c62cce 100644 --- a/core/src/engine/noop.ts +++ b/core/src/engine/noop.ts @@ -1,22 +1,24 @@ import { getLogger } from "@davidsouther/jiffies/lib/esm/log.js"; import { Content } from "../content/content.js"; import { LOGGER as ROOT_LOGGER } from "../util.js"; -import type { PipelineSettings } from "../ailly.js"; -import type { Message } from "./index.js"; import { addContentMessages } from "./messages.js"; +import { EngineGenerate } from "."; const LOGGER = getLogger("@ailly/core:noop"); -const asMessages = (content: Content) => [ - { role: "user", content: content.prompt } satisfies Message, - ...(content.response - ? [{ role: "assistant", content: content.response } satisfies Message] - : []), -]; - export const DEFAULT_MODEL = "NOOP"; -const NOOP_TIMEOUT = Number(process.env["AILLY_NOOP_TIMEOUT"] ?? 750); +export const TIMEOUT = { + timeout: 0, + setTimeout(timeout: number) { + TIMEOUT.timeout = timeout; + }, + resetTimeout() { + TIMEOUT.setTimeout(Number(process.env["AILLY_NOOP_TIMEOUT"] ?? 750)); + }, +}; +TIMEOUT.resetTimeout(); export const name = "noop"; + export async function format( contents: Content[], context: Record @@ -25,31 +27,63 @@ export async function format( addContentMessages(content, context); } } -export async function generate( - content: Content, - _: PipelineSettings -): Promise<{ debug: D; message: string }> { + +export const generate: EngineGenerate = (content: Content, _) => { LOGGER.level = ROOT_LOGGER.level; LOGGER.format = ROOT_LOGGER.format; - await new Promise((resolve) => { - setTimeout(() => resolve(), NOOP_TIMEOUT); - }); + const system = content.context.system?.map((s) => s.content).join("\n"); const messages = content.meta?.messages ?.map((m) => `${m.role}: ${m.content}`) .join("\n"); + const message = + process.env["AILLY_NOOP_RESPONSE"] ?? 
+    [
+      `noop response for ${content.name}:`,
+      system,
+      messages,
+      content.prompt,
+    ].join("\n");
+
+  let error: Error | undefined;
+  const stream = new TextEncoderStream();
+  const done = Promise.resolve()
+    .then(async () => {
+      await sleep(TIMEOUT.timeout);
+      const writer = await stream.writable.getWriter();
+      try {
+        await writer.ready;
+        if (process.env["AILLY_NOOP_STREAM"]) {
+          let first = true;
+          for (const word of message.split(" ")) {
+            await writer.write((first ? "" : " ") + word);
+            first = false;
+            await sleep(TIMEOUT.timeout / 10);
+          }
+        }
+      } finally {
+        writer.close();
+      }
+    })
+    .catch((err) => {
+      error = err as Error;
+    });
+
   return {
-    message:
-      process.env["AILLY_NOOP_RESPONSE"] ??
-      [
-        `noop response for ${content.name}:`,
-        system,
-        messages,
-        content.prompt,
-      ].join("\n"),
-    debug: { system: content.context.system } as unknown as D,
+    stream: stream.readable,
+    message: () => message,
+    debug: () => (error ? { finish: "failed", error } : {}),
+    done,
   };
+};
+
+function sleep(duration: number) {
+  if (isFinite(duration) && duration > 16)
+    return new Promise<void>((resolve) => {
+      setTimeout(() => resolve(), duration);
+    });
 }
+
 export async function vector(s: string, _: unknown): Promise<number[]> {
   return [0.0, 1.0];
 }
diff --git a/core/src/engine/openai.ts b/core/src/engine/openai.ts
index f204dbd..865d065 100644
--- a/core/src/engine/openai.ts
+++ b/core/src/engine/openai.ts
@@ -2,9 +2,11 @@ import { OpenAI, toFile } from "openai";
 import { assertExists } from "@davidsouther/jiffies/lib/esm/assert.js";
 import type { Content } from "../content/content.js";
 import type { PipelineSettings } from "../ailly.js";
-import type { Message, Summary } from "./index.js";
+import type { EngineGenerate, Message, Summary } from "./index.js";
 import { LOGGER, isDefined } from "../util.js";
 import { encode } from "../encoding.js";
+import { ChatCompletionChunk } from "openai/resources/index.mjs";
+import { Stream } from "openai/streaming.mjs";
 
 export const name = "openai";
 
@@ -15,10 +17,14 @@ const MODEL = "gpt-4-0613";
 // const MODEL = "gpt-3.5-turbo-16k-0613";
 const EMBEDDING_MODEL = "text-embedding-ada-002";
 
-export async function generate(
+export interface OpenAIDebug {
+  finish?: string;
+  error?: Error;
+}
+export const generate: EngineGenerate<OpenAIDebug> = (
   c: Content,
   { model = MODEL }: PipelineSettings
-): Promise<{ message: string; debug: unknown }> {
+) => {
   const apiKey = assertExists(
     process.env["OPENAI_API_KEY"],
     "Missing OPENAI_API_KEY"
   );
@@ -43,40 +49,68 @@
   };
 
   try {
-    const completions = await callOpenAiWithRateLimit(openai, body);
-    if (!completions) {
-      throw new Error(
-        "Failed to get completions and call with rate limit did not itself error"
-      );
-    }
+    let message = "";
+    let chunkNum = 0;
+    const stream = new TransformStream();
+
+    const done = Promise.resolve().then(async () => {
+      const completions = await callOpenAiWithRateLimit(openai, {
+        ...body,
+        stream: true,
+      });
+      if (!completions) {
+        throw new Error(
+          "Failed to get completions and call with rate limit did not itself error"
+        );
+      }
+      LOGGER.info(`Begin streaming response from OpenAI for ${c.name}`);
+
+      for await (const block of completions) {
+        LOGGER.debug(`Received chunk ${chunkNum++} from OpenAI for ${c.name}`);
+        const writer = stream.writable.getWriter();
+        await writer.ready;
+        const chunk = block.choices[0]?.delta.content ?? "";
+        message += chunk;
+        await writer.write(chunk);
+        writer.releaseLock();
+      }
+
+      await stream.writable.getWriter().close();
+    });
 
-    const choice = completions.choices[0];
     LOGGER.debug(`Response from OpenAI for ${c.name}`, {
-      id: completions.id,
-      finish_reason: choice.finish_reason,
+      // id: completions.id,
+      // finish_reason: choice.finish_reason,
     });
 
     return {
-      message: choice.message.content ?? "",
-      debug: {
-        id: completions.id,
-        model: completions.model,
-        usage: completions.usage,
-        finish: choice.finish_reason,
-      },
+      stream: stream.readable,
+      message: () => message,
+      debug: () => ({
+        // id: completions.id,
+        // model: completions.model,
+        // usage: completions.usage,
+        // finish: choice.finish_reason,
+      }),
+      done,
     };
   } catch (e) {
    console.warn(`Error from OpenAI for ${c.name}`, e);
     return {
-      message: "💩",
-      debug: { finish: "Failed", error: { message: (e as Error).message } },
+      // Emit the failure marker as a one-chunk stream.
+      stream: new ReadableStream<string>({
+        start(controller) {
+          controller.enqueue("💩");
+          controller.close();
+        },
+      }),
+      message: () => "💩",
+      debug: () => ({
+        finish: "failed",
+        error: { message: (e as Error).message },
+      }),
+      done: Promise.reject(e),
     };
   }
-}
+};
 
 async function callOpenAiWithRateLimit(
   openai: OpenAI,
-  content: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
-): Promise<OpenAI.Chat.Completions.ChatCompletion | undefined> {
+  content: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
+): Promise<Stream<ChatCompletionChunk> | undefined> {
   let retry = 3;
   while (retry > 0) {
     retry -= 1;
diff --git a/core/src/plugin/index.ts b/core/src/plugin/index.ts
index 4e486c1..3db4e22 100644
--- a/core/src/plugin/index.ts
+++ b/core/src/plugin/index.ts
@@ -20,7 +20,7 @@ export const PLUGINS: Record<string, PluginBuilder> = {
 };
 
 export async function getPlugin(
-  name: string
+  name: keyof typeof PLUGINS | string
 ): Promise<{ default: PluginBuilder }> {
   if (name.startsWith("file://")) {
     return import(name);
diff --git a/core/src/util.ts b/core/src/util.ts
index 614507b..c6b9629 100644
--- a/core/src/util.ts
+++ b/core/src/util.ts
@@ -4,3 +4,15 @@ export const LOGGER = getLogger("@ailly/core");
 LOGGER.level = getLogLevel(process.env["AILLY_LOG_LEVEL"]);
 
 export const isDefined = <T>(t: T | undefined): t is T => t !== undefined;
+
+Promise.withResolvers =
+  Promise.withResolvers ??
+  function makePromise<T>(): PromiseWithResolvers<T> {
+    let resolve: (t: T | PromiseLike<T>) => void = () => {};
+    let reject: (reason?: any) => void = () => {};
+    const promise = new Promise<T>((r, j) => {
+      resolve = r;
+      reject = j;
+    });
+    return { promise, resolve, reject };
+  };
diff --git a/integ/06_stream/stream.sh b/integ/06_stream/stream.sh
new file mode 100755
index 0000000..5567cc7
--- /dev/null
+++ b/integ/06_stream/stream.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+cd $(dirname $0)
+set -x
+set -e
+
+AILLY_NOOP_STREAM=y ailly --prompt "Tell me a joke" --stream | tee out
+[ -s out ]
+rm out
diff --git a/integ/integ.sh b/integ/integ.sh
index c320378..172a47e 100755
--- a/integ/integ.sh
+++ b/integ/integ.sh
@@ -34,6 +34,9 @@ rm 04_edit/{err,out}
 echo "conversations"
 ./05_conversation/conversation.sh
 
+echo "Stream"
+./06_stream/stream.sh
+
 echo "Pipes"
 ./10_std_pipes/pipes.sh
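
For reference, a minimal sketch of how a caller consumes the streaming contract this patch introduces: generateOne assigns content.responseStream before the request settles, so output can be drained while the engine is still responding, which is what cli/index.js does for /dev/stdout above. The helper name below is illustrative only and is not part of the patch; it assumes a Content value that has already been passed through generateOne and a Node runtime where web ReadableStream is async-iterable.

// Sketch only: drain a streamed response to stdout (cf. cli/index.js above).
import type { Content } from "@ailly/core/src/content/content.js"; // assumed import path

async function streamToStdout(content: Content): Promise<void> {
  if (content.responseStream === undefined) {
    throw new Error("responseStream not set; run generateOne first");
  }
  for await (const chunk of content.responseStream) {
    process.stdout.write(chunk);
  }
  process.stdout.write("\n");
}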