From 2976782e0b9958502c7b5e640dd7459e8b4d5522 Mon Sep 17 00:00:00 2001
From: igorMSoares
Date: Sat, 25 Jan 2025 15:25:55 -0300
Subject: [PATCH] feat: add system prompt support to ReplicateLLM and OpenAI

---
 packages/providers/openai/src/llm.ts    | 10 +++++++++-
 packages/providers/replicate/src/llm.ts | 10 +++++++++-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/packages/providers/openai/src/llm.ts b/packages/providers/openai/src/llm.ts
index 492db09905..09c3c6a6c0 100644
--- a/packages/providers/openai/src/llm.ts
+++ b/packages/providers/openai/src/llm.ts
@@ -155,6 +155,7 @@ export class OpenAI extends ToolCallLLM {
   topP: number;
   maxTokens?: number | undefined;
   additionalChatOptions?: OpenAIAdditionalChatOptions | undefined;
+  systemPrompt?: string | undefined;
 
   // OpenAI session params
   apiKey?: string | undefined = undefined;
@@ -185,6 +186,7 @@ export class OpenAI extends ToolCallLLM {
     this.temperature = init?.temperature ?? 0.1;
     this.topP = init?.topP ?? 1;
     this.maxTokens = init?.maxTokens ?? undefined;
+    this.systemPrompt = init?.systemPrompt ?? undefined;
 
     this.maxRetries = init?.maxRetries ?? 10;
     this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
@@ -337,12 +339,18 @@ export class OpenAI extends ToolCallLLM {
     | AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>
   > {
     const { messages, stream, tools, additionalChatOptions } = params;
+    const systemMessage: ChatMessage | undefined =
+      this.systemPrompt
+        ? { role: "system", content: this.systemPrompt }
+        : undefined;
     const baseRequestParams = {
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
       tools: tools?.map(OpenAI.toTool),
-      messages: OpenAI.toOpenAIMessage(messages),
+      messages: OpenAI.toOpenAIMessage(
+        systemMessage ? [systemMessage].concat(messages) : messages,
+      ),
       top_p: this.topP,
       ...Object.assign({}, this.additionalChatOptions, additionalChatOptions),
     };
diff --git a/packages/providers/replicate/src/llm.ts b/packages/providers/replicate/src/llm.ts
index 7012af8c02..efc2966480 100644
--- a/packages/providers/replicate/src/llm.ts
+++ b/packages/providers/replicate/src/llm.ts
@@ -111,6 +111,7 @@ export class ReplicateLLM extends BaseLLM {
   topP: number;
   maxTokens?: number;
   replicateSession: ReplicateSession;
+  systemPrompt?: string | undefined;
 
   constructor(init?: Partial<ReplicateLLM> & { noWarn?: boolean }) {
     super();
@@ -135,6 +136,7 @@
       init?.maxTokens ??
       ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow; // For Replicate, the default is 500 tokens which is too low.
     this.replicateSession = init?.replicateSession ?? new ReplicateSession();
+    this.systemPrompt = init?.systemPrompt ?? undefined;
   }
 
   get metadata() {
@@ -175,6 +177,12 @@
   }
 
   mapMessagesToPromptLlama3(messages: ChatMessage[]) {
+    const systemPrompt = this.systemPrompt
+      ? "<|begin_of_text|><|start_header_id|>system<|end_header_id|>" +
+        this.systemPrompt +
+        "<|eot_id|>"
+      : undefined;
+
     return {
       prompt:
         "<|begin_of_text|>" +
@@ -196,7 +204,7 @@
           );
         }, "") +
         "<|start_header_id|>assistant<|end_header_id|>\n\n",
-      systemPrompt: undefined,
+      systemPrompt,
     };
   }
 
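
Usage sketch: with this patch applied, a caller sets the system prompt once at
construction time instead of prepending a system message to every chat call.
Only the `systemPrompt` option itself comes from the diff; the package entry
points and model names below are assumptions for illustration.

    import { OpenAI } from "@llamaindex/openai";
    import { ReplicateLLM } from "@llamaindex/replicate";

    // OpenAI: the prompt is prepended as a { role: "system" } message
    // before messages are converted with OpenAI.toOpenAIMessage().
    const openai = new OpenAI({
      model: "gpt-4o-mini", // assumed model name
      systemPrompt: "You are a concise assistant.",
    });

    // ReplicateLLM: for Llama 3 models the prompt is wrapped in
    // <|start_header_id|>system<|end_header_id|> ... <|eot_id|> tokens
    // and returned as the systemPrompt field of the request.
    const llama3 = new ReplicateLLM({
      model: "llama-3-70b-instruct", // assumed key in ALL_AVAILABLE_REPLICATE_MODELS
      systemPrompt: "You are a concise assistant.",
    });

    // Every chat call now carries the configured system prompt.
    const response = await openai.chat({
      messages: [{ role: "user", content: "Summarize this patch." }],
    });
    console.log(response.message.content);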