Skip to content

Commit

Permalink
refactor: introduce task-llm package which breaks out ai/llm tasks fr…
Browse files Browse the repository at this point in the history
…om core
  • Loading branch information
sroussey committed Jan 11, 2025
1 parent ce92973 commit ba69f78
Show file tree
Hide file tree
Showing 34 changed files with 217 additions and 114 deletions.
Binary file modified bun.lockb
Binary file not shown.
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
],
"scripts": {
"build": "bun run build:packages && bun run build:examples",
"build:packages": "bun run build:core && bun run build:providers && bun run build:cli && bun run build:web",
"build:packages": "bun run build:core && bun run build:task-llm && bun run build:providers",
"build:core": "cd packages/core && bun run build",
"build:task-llm": "cd packages/task-llm && bun run build",
"build:providers": "(cd packages/provider-hf-transformers && bun run build) && (cd packages/provider-tf-mediapipe && bun run build)",
"build:examples": "bun run bun run build:cli && bun run build:web",
"build:cli": "cd samples/cli && bun run build",
"build:web": "cd samples/web && bun run build",
"clean": "rm -rf node_modules packages/*/node_modules packages/*/dist packages/*/src/**\\.d\\.ts packages/*/src/**\\.map examples/*/node_modules examples/*/dist examples/*/src/**\\.d\\.ts examples/*/src/**\\.map",
"watch:packages": "concurrently --kill-others -c 'auto' -n core,hf-trans,tf-media 'cd packages/core && bun run watch' 'sleep 5 && cd packages/provider-hf-transformers && bun run watch' 'sleep 5 && cd packages/provider-tf-mediapipe && bun run watch'",
"clean": "rm -rf node_modules packages/*/node_modules packages/*/dist packages/*/src/**/*\\.d\\.ts packages/*/src/**/*\\.map examples/*/node_modules examples/*/dist examples/*/src/**/*\\.d\\.ts examples/*/src/**/*\\.map",
"watch:packages": "concurrently --kill-others -c 'auto' -n core,task-llm,hf-trans,tf-media 'cd packages/core && bun run watch' 'cd packages/task-llm && bun run watch' 'sleep 5 && cd packages/provider-hf-transformers && bun run watch' 'sleep 5 && cd packages/provider-tf-mediapipe && bun run watch'",
"docs": "typedoc",
"format": "eslint \"{packages|examples}/*/src/**/*.{js,ts,tsx,json}\" --fix && prettier \"{packages|examples}/*/src/**/*.{js,ts,tsx,json}\" --check --write",
"release": "bun run build && bun publish",
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/provider/ProviderRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import type { JobQueueLlmTask } from "../task/base/JobQueueLlmTask";
import type { ModelProcessorEnum } from "../model/Model";
import { Job, JobConstructorDetails } from "../job/base/Job";
import type { JobQueue } from "../job/base/JobQueue";
import { TaskInput, TaskOutput } from "../task/base/Task";
import { JobQueueTask } from "../task";

export enum JobQueueRunType {
local = "local",
Expand Down Expand Up @@ -46,7 +46,7 @@ export class ProviderRegistry<Input, Output> {

jobAsRunFn(runtype: string, modelType: ModelProcessorEnum) {
const fn = this.runFnRegistry[runtype]?.[modelType];
return async (task: JobQueueLlmTask, input: Input) => {
return async (task: JobQueueTask, input: Input) => {
const queue = this.queues.get(modelType)!;
const job = new ProviderJob({
queueName: queue.queue,
Expand Down
3 changes: 1 addition & 2 deletions packages/core/src/task/base/TaskIOTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import { similarity_fn } from "../SimilarityTask";
import { Document } from "../../source/Document";

export type AnyNumberArray =
Expand Down Expand Up @@ -141,7 +140,7 @@ interface TsTypes {
log_level: LogLevel;
doc_parser: DocParser;
doc_variant: DocVariant;
similarity_fn: (typeof similarity_fn)[number];
similarity_fn: (a: number[], b: number[]) => number;
document: Document;
}

Expand Down
10 changes: 0 additions & 10 deletions packages/core/src/task/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,7 @@ export * from "./base/TaskGraphBuilder";
export * from "./base/ArrayTask";
export * from "./base/TaskIOTypes";
export * from "./base/JobQueueTask";
export * from "./base/JobQueueLlmTask";
export * from "./DocumentSplitterTask";
export * from "./LambdaTask";
export * from "./DebugLogTask";
export * from "./JsonTask";
export * from "./JavaScriptTask";
export * from "./DownloadModelTask";
export * from "./TextRewriterTask";
export * from "./TextGenerationTask";
export * from "./TextSummaryTask";
export * from "./TextEmbeddingTask";
export * from "./TextQuestionAnswerTask";
export * from "./TextTranslationTask";
export * from "./SimilarityTask";
3 changes: 2 additions & 1 deletion packages/provider-hf-transformers/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"dist"
],
"dependencies": {
"ellmers-core": "workspace:packages/core"
"ellmers-core": "workspace:packages/core",
"ellmers-task-llm": "workspace:packages/task-llm"
}
}
14 changes: 9 additions & 5 deletions packages/provider-hf-transformers/src/bindings/all_inmemory.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import { getProviderRegistry } from "ellmers-core";
import {
getProviderRegistry,
InMemoryJobQueue,
ModelProcessorEnum,
ConcurrencyLimiter,
TaskInput,
TaskOutput,
} from "ellmers-core";
import { registerHuggingfaceLocalTasks } from "./local_hf";
import { InMemoryJobQueue } from "ellmers-core";
import { ModelProcessorEnum } from "ellmers-core";
import { ConcurrencyLimiter } from "ellmers-core";
import { TaskInput, TaskOutput } from "ellmers-core";
import "../model/ONNXModelSamples";

export async function registerHuggingfaceLocalTasksInMemory() {
registerHuggingfaceLocalTasks();
Expand Down
21 changes: 10 additions & 11 deletions packages/provider-hf-transformers/src/bindings/local_hf.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
import { ModelProcessorEnum, getProviderRegistry } from "ellmers-core";
import {
DownloadModelTask,
TextEmbeddingTask,
TextGenerationTask,
TextQuestionAnswerTask,
TextRewriterTask,
TextSummaryTask,
TextTranslationTask,
} from "ellmers-task-llm";
import {
HuggingFaceLocal_DownloadRun,
HuggingFaceLocal_EmbeddingRun,
Expand All @@ -7,17 +17,6 @@ import {
HuggingFaceLocal_TextSummaryRun,
HuggingFaceLocal_TextTranslationRun,
} from "../provider/HuggingFaceLocal_TaskRun";
import { ModelProcessorEnum } from "ellmers-core";
import { getProviderRegistry } from "ellmers-core";
import {
DownloadModelTask,
TextEmbeddingTask,
TextGenerationTask,
TextQuestionAnswerTask,
TextRewriterTask,
TextSummaryTask,
TextTranslationTask,
} from "ellmers-core";

export async function registerHuggingfaceLocalTasks() {
const ProviderRegistry = getProviderRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import {
type TranslationSingle,
TextStreamer,
} from "@huggingface/transformers";
import { findModelByName } from "ellmers-core";
import { findModelByName, ElVector } from "ellmers-core";
import { ONNXTransformerJsModel } from "../model/ONNXTransformerJsModel";
import {
ElVector,
JobQueueLlmTask,
DownloadModelTask,
DownloadModelTaskInput,
Expand All @@ -46,7 +45,7 @@ import {
TextTranslationTask,
TextTranslationTaskInput,
TextTranslationTaskOutput,
} from "ellmers-core";
} from "ellmers-task-llm";

env.cacheDir = "./.cache";

Expand Down
5 changes: 3 additions & 2 deletions packages/provider-hf-transformers/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
"rootDir": "./src",
"paths": {
"#/*": ["./src/*"],
"ellmers-core": ["../core/src"]
"ellmers-core": ["../core/src"],
"ellmers-task-llm": ["../task-llm/src"]
}
},
"references": [{ "path": "../core" }]
"references": [{ "path": "../core" }, { "path": "../task-llm" }]
}
3 changes: 2 additions & 1 deletion packages/provider-tf-mediapipe/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"dist"
],
"dependencies": {
"ellmers-core": "workspace:packages/core"
"ellmers-core": "workspace:packages/core",
"ellmers-task-llm": "workspace:packages/task-llm"
}
}
5 changes: 2 additions & 3 deletions packages/provider-tf-mediapipe/src/bindings/local_mp.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { ModelProcessorEnum } from "ellmers-core";
import { getProviderRegistry } from "ellmers-core";
import { DownloadModelTask, TextEmbeddingTask } from "ellmers-core";
import { ModelProcessorEnum, getProviderRegistry } from "ellmers-core";
import { DownloadModelTask, TextEmbeddingTask } from "ellmers-task-llm";
import {
MediaPipeTfJsLocal_Download,
MediaPipeTfJsLocal_Embedding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
// *******************************************************************************

import { FilesetResolver, TextEmbedder } from "@mediapipe/tasks-text";
import { DownloadModelTask, DownloadModelTaskInput } from "ellmers-core";
import { TextEmbeddingTask, TextEmbeddingTaskInput } from "ellmers-core";
import { findModelByName } from "ellmers-core";
import { ElVector } from "ellmers-core";
import { findModelByName, ElVector } from "ellmers-core";
import {
DownloadModelTask,
DownloadModelTaskInput,
TextEmbeddingTask,
TextEmbeddingTaskInput,
} from "ellmers-task-llm";
import { MediaPipeTfJsModel } from "../model/MediaPipeModel";

/**
Expand Down
5 changes: 3 additions & 2 deletions packages/provider-tf-mediapipe/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
"rootDir": "./src",
"paths": {
"#/*": ["./src/*"],
"ellmers-core": ["../core/src"]
"ellmers-core": ["../core/src"],
"ellmers-task-llm": ["../task-llm/src"]
}
},
"references": [{ "path": "../core" }]
"references": [{ "path": "../core" }, { "path": "../task-llm" }]
}
37 changes: 37 additions & 0 deletions packages/task-llm/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"name": "ellmers-task-llm",
"type": "module",
"version": "0.0.1",
"description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.",
"scripts": {
"watch": "concurrently -c 'auto' -n 'tf-meda:' 'bun:watch-*'",
"watch-browser": "bun build --watch --target=browser --sourcemap=external --external ellmers-core --outdir ./dist ./src/index.ts",
"watch-worker": "bun build --watch --target=bun --sourcemap=external --external ellmers-core --outdir ./dist ./src/worker.ts",
"watch-types": "tsc --watch --preserveWatchOutput",
"build": "bun run build-clean && bun run build-types && bun run build-browser && bun run build-worker",
"build-clean": "rm -fr dist/* tsconfig.tsbuildinfo",
"build-browser": "bun build --target=browser --sourcemap=external --external ellmers-core --outdir ./dist ./src/index.ts",
"build-worker": "bun build --target=bun --sourcemap=external --external ellmers-core --outdir ./dist ./src/worker.ts",
"build-types": "tsc",
"lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
"test": "bun test"
},
"module": "dist/index.js",
"types": "dist/index.d.ts",
"exports": {
".": {
"import": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"./worker": {
"import": "./dist/worker.js",
"types": "./dist/worker.d.ts"
}
},
"files": [
"dist"
],
"dependencies": {
"ellmers-core": "workspace:packages/core"
}
}
1 change: 1 addition & 0 deletions packages/task-llm/src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from "./task";
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import { Document, DocumentFragment } from "../source/Document";
import { SingleTask } from "./base/Task";
import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilder";
import { CreateMappedType } from "./base/TaskIOTypes";
import { TaskRegistry } from "./base/TaskRegistry";
import {
Document,
DocumentFragment,
SingleTask,
TaskGraphBuilder,
TaskGraphBuilderHelper,
CreateMappedType,
TaskRegistry,
} from "ellmers-core";

export type DocumentSplitterTaskInput = CreateMappedType<typeof DocumentSplitterTask.inputs>;
export type DocumentSplitterTaskOutput = CreateMappedType<typeof DocumentSplitterTask.outputs>;
Expand Down Expand Up @@ -73,7 +77,7 @@ export const DocumentSplitter = (input: DocumentSplitterTaskInput) => {
return DocumentSplitterBuilder(input).run();
};

declare module "./base/TaskGraphBuilder" {
declare module "ellmers-core" {
interface TaskGraphBuilder {
DocumentSplitter: TaskGraphBuilderHelper<DocumentSplitterTask>;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import { ConvertAllToArrays, ConvertSomeToOptionalArray, arrayTaskFactory } from "./base/ArrayTask";
import { CreateMappedType } from "./base/TaskIOTypes";
import { TaskRegistry } from "./base/TaskRegistry";
import {
TaskGraphBuilder,
TaskGraphBuilderHelper,
CreateMappedType,
TaskRegistry,
ConvertAllToArrays,
ConvertSomeToOptionalArray,
arrayTaskFactory,
TaskOutput,
JobQueueTaskConfig,
findModelByName,
ModelUseCaseEnum,
} from "ellmers-core";
import { JobQueueLlmTask } from "./base/JobQueueLlmTask";
import { TaskOutput } from "./base/Task";
import { JobQueueTaskConfig } from "./base/JobQueueTask";
import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilder";
import { ModelUseCaseEnum } from "../model/Model";
import { findModelByName } from "../model/InMemoryStorage";

export type DownloadModelTaskInput = CreateMappedType<typeof DownloadModelTask.inputs>;
export type DownloadModelTaskOutput = CreateMappedType<typeof DownloadModelTask.outputs>;
Expand Down Expand Up @@ -106,6 +111,7 @@ export class DownloadModelTask extends JobQueueLlmTask {
static readonly type = "DownloadModelTask";
static readonly category = "Text Model";
}

TaskRegistry.registerTask(DownloadModelTask);

type DownloadModelCompoundTaskInput = ConvertSomeToOptionalArray<DownloadModelTaskInput, "model">;
Expand All @@ -122,7 +128,7 @@ export const DownloadModel = (input: DownloadModelCompoundTaskInput) => {
}
};

declare module "./base/TaskGraphBuilder" {
declare module "ellmers-core" {
interface TaskGraphBuilder {
DownloadModel: TaskGraphBuilderHelper<DownloadModelCompoundTaskInput>;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import { SingleTask, TaskConfig, TaskOutput } from "./base/Task";
import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilder";
import { CreateMappedType, ElVector, ValueTypesIndex } from "./base/TaskIOTypes";
import { TaskRegistry } from "./base/TaskRegistry";
import {
CreateMappedType,
TaskRegistry,
TaskGraphBuilder,
TaskGraphBuilderHelper,
ElVector,
ValueTypesIndex,
SingleTask,
TaskOutput,
TaskConfig,
} from "ellmers-core";

// ===============================================================================

Expand Down Expand Up @@ -141,7 +148,7 @@ export const Similarity = (input: SimilarityTaskInput) => {
return SimilarityBuilder(input).run();
};

declare module "./base/TaskGraphBuilder" {
declare module "ellmers-core" {
interface TaskGraphBuilder {
Similarity: TaskGraphBuilderHelper<SimilarityTaskInput>;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import {
ConvertSomeToArray,
ConvertSomeToOptionalArray,
arrayTaskFactory,
} from "./base/ArrayTask";
import { CreateMappedType } from "./base/TaskIOTypes";
import { TaskRegistry } from "./base/TaskRegistry";
CreateMappedType,
TaskRegistry,
JobQueueTaskConfig,
TaskGraphBuilder,
TaskGraphBuilderHelper,
} from "ellmers-core";
import { JobQueueLlmTask } from "./base/JobQueueLlmTask";
import { JobQueueTaskConfig } from "./base/JobQueueTask";
import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilder";

export type TextEmbeddingTaskInput = CreateMappedType<typeof TextEmbeddingTask.inputs>;
export type TextEmbeddingTaskOutput = CreateMappedType<typeof TextEmbeddingTask.outputs>;
Expand Down Expand Up @@ -60,7 +61,7 @@ export const TextEmbedding = (input: TextEmbeddingCompoundTaskInput) => {
return new TextEmbeddingCompoundTask({ input }).run();
};

declare module "./base/TaskGraphBuilder" {
declare module "ellmers-core" {
interface TaskGraphBuilder {
TextEmbedding: TaskGraphBuilderHelper<TextEmbeddingCompoundTaskInput>;
}
Expand Down
Loading

0 comments on commit ba69f78

Please sign in to comment.