Skip to content

Commit

Permalink
add multi-task support
Browse files Browse the repository at this point in the history
  • Loading branch information
NawarA committed Aug 6, 2024
1 parent 025b0b0 commit 1f28985
Show file tree
Hide file tree
Showing 11 changed files with 854 additions and 27 deletions.
23 changes: 19 additions & 4 deletions javascript/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
modelOptionMapper,
ModelOptions
} from "./interface/model";
import { Task } from "./interface/task";

class Client {
constructor(apiKey: string) {
Expand Down Expand Up @@ -49,8 +50,11 @@ export default class Bytez {
#client: Client;

list = {
/** Lists the currently available models, and provides basic information about each one, such as RAM required */
models: () => this.#client._request("model/list"),
models: {
/** Lists the currently available models, and provides basic information about each one, such as RAM required */
all: () => this.#client._request("model/list"),
byTask: (task: Task) => this.#client._request(`model/list?task=${task}`)
},
/** List your serverless instances */
instances: () => this.#client._request("model/instances")
};
Expand All @@ -77,8 +81,13 @@ class Model {
this.id = modelId;
this.options = modelOptionDefaults;
this.#body = { model: this.id };

// this.#client._request(`model/get/${modelId}`).then(({ task }) => {
// this.#task = task;
// });
}
#client: Client;
// #task: string;
#body: PostBody;
/** The HuggingFace modelId, for example `openai-community/gpt2` */
id: string;
Expand All @@ -93,7 +102,7 @@ class Model {
async load(options?: ModelOptions): Promise<any> {
let { status, error } = await this.start(options);

status ??= error.includes("already loaded") ? "RUNNING" : "";
status ??= error?.includes?.("already loaded") ? "RUNNING" : "";

while (status !== "FAILED" && status !== "RUNNING") {
({ status } = await this.status());
Expand Down Expand Up @@ -132,7 +141,13 @@ class Model {
/** Run model */
run(input: any, options: Inference = {}) {
const { stream = false, ...params } = options;
const postBody = { ...this.#body, input, params, stream };
let postBody = { stream, params, ...this.#body };

if (input?.constructor === Object) {
postBody = { ...postBody, ...input };
} else {
postBody.input = input;
}

return this.#client._request("model/run", postBody);
}
Expand Down
2 changes: 1 addition & 1 deletion javascript/src/interface/postBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export default interface PostBody {
* text-generation: string
* chat-model: [role: string, message: string]
*/
input?: string | any;
input?: any;
/**
* Stream back text
*
Expand Down
31 changes: 31 additions & 0 deletions javascript/src/interface/task.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
export type Task =
| "audio-classification"
| "automatic-speech-recognition"
| "depth-estimation"
| "document-question-answering"
| "feature-extraction"
| "fill-mask"
| "image-classification"
| "image-feature-extraction"
| "image-segmentation"
| "image-to-text"
| "mask-generation"
| "object-detection"
| "question-answering"
| "sentence-similarity"
| "summarization"
| "text-classification"
| "text-generation"
| "text-to-audio"
| "text-to-image"
| "text-to-speech"
| "text-to-video"
| "text2text-generation"
| "token-classification"
| "translation"
| "unconditional-image-generation"
| "video-classification"
| "visual-question-answering"
| "zero-shot-classification"
| "zero-shot-image-classification"
| "zero-shot-object-detection";
Loading

0 comments on commit 1f28985

Please sign in to comment.