Skip to content

Commit

Permalink
feat(openai): Refactor to allow easier subclassing (#7598)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Jan 25, 2025
1 parent 67ca4f9 commit cb335ba
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 145 deletions.
289 changes: 144 additions & 145 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,145 +145,6 @@ export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
}
}

function openAIResponseToChatMessage(
message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
rawResponse: OpenAIClient.Chat.Completions.ChatCompletion,
includeRawResponse?: boolean
): BaseMessage {
const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
| OpenAIToolCall[]
| undefined;
switch (message.role) {
case "assistant": {
const toolCalls = [];
const invalidToolCalls = [];
for (const rawToolCall of rawToolCalls ?? []) {
try {
toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
}
}
const additional_kwargs: Record<string, unknown> = {
function_call: message.function_call,
tool_calls: rawToolCalls,
};
if (includeRawResponse !== undefined) {
additional_kwargs.__raw_response = rawResponse;
}
const response_metadata: Record<string, unknown> | undefined = {
model_name: rawResponse.model,
...(rawResponse.system_fingerprint
? {
usage: { ...rawResponse.usage },
system_fingerprint: rawResponse.system_fingerprint,
}
: {}),
};

if (message.audio) {
additional_kwargs.audio = message.audio;
}

return new AIMessage({
content: message.content || "",
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs,
response_metadata,
id: rawResponse.id,
});
}
default:
return new ChatMessage(message.content || "", message.role ?? "unknown");
}
}

function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>,
rawResponse: OpenAIClient.Chat.Completions.ChatCompletionChunk,
defaultRole?: OpenAIRoleEnum,
includeRawResponse?: boolean
) {
const role = delta.role ?? defaultRole;
const content = delta.content ?? "";
let additional_kwargs: Record<string, unknown>;
if (delta.function_call) {
additional_kwargs = {
function_call: delta.function_call,
};
} else if (delta.tool_calls) {
additional_kwargs = {
tool_calls: delta.tool_calls,
};
} else {
additional_kwargs = {};
}
if (includeRawResponse) {
additional_kwargs.__raw_response = rawResponse;
}

if (delta.audio) {
additional_kwargs.audio = {
...delta.audio,
index: rawResponse.choices[0].index,
};
}

const response_metadata = { usage: { ...rawResponse.usage } };
if (role === "user") {
return new HumanMessageChunk({ content, response_metadata });
} else if (role === "assistant") {
const toolCallChunks: ToolCallChunk[] = [];
if (Array.isArray(delta.tool_calls)) {
for (const rawToolCall of delta.tool_calls) {
toolCallChunks.push({
name: rawToolCall.function?.name,
args: rawToolCall.function?.arguments,
id: rawToolCall.id,
index: rawToolCall.index,
type: "tool_call_chunk",
});
}
}
return new AIMessageChunk({
content,
tool_call_chunks: toolCallChunks,
additional_kwargs,
id: rawResponse.id,
response_metadata,
});
} else if (role === "system") {
return new SystemMessageChunk({ content, response_metadata });
} else if (role === "developer") {
return new SystemMessageChunk({
content,
response_metadata,
additional_kwargs: {
__openai_role__: "developer",
},
});
} else if (role === "function") {
return new FunctionMessageChunk({
content,
additional_kwargs,
name: delta.name,
response_metadata,
});
} else if (role === "tool") {
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: delta.tool_call_id,
response_metadata,
});
} else {
return new ChatMessageChunk({ content, role, response_metadata });
}
}

// Used in LangSmith, export is important here
export function _convertMessagesToOpenAIParams(
messages: BaseMessage[],
Expand Down Expand Up @@ -1290,6 +1151,146 @@ export class ChatOpenAI<
return params;
}

protected _convertOpenAIChatCompletionMessageToBaseMessage(
message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
rawResponse: OpenAIClient.Chat.Completions.ChatCompletion
): BaseMessage {
const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
| OpenAIToolCall[]
| undefined;
switch (message.role) {
case "assistant": {
const toolCalls = [];
const invalidToolCalls = [];
for (const rawToolCall of rawToolCalls ?? []) {
try {
toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
}
}
const additional_kwargs: Record<string, unknown> = {
function_call: message.function_call,
tool_calls: rawToolCalls,
};
if (this.__includeRawResponse !== undefined) {
additional_kwargs.__raw_response = rawResponse;
}
const response_metadata: Record<string, unknown> | undefined = {
model_name: rawResponse.model,
...(rawResponse.system_fingerprint
? {
usage: { ...rawResponse.usage },
system_fingerprint: rawResponse.system_fingerprint,
}
: {}),
};

if (message.audio) {
additional_kwargs.audio = message.audio;
}

return new AIMessage({
content: message.content || "",
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs,
response_metadata,
id: rawResponse.id,
});
}
default:
return new ChatMessage(
message.content || "",
message.role ?? "unknown"
);
}
}

protected _convertOpenAIDeltaToBaseMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>,
rawResponse: OpenAIClient.Chat.Completions.ChatCompletionChunk,
defaultRole?: OpenAIRoleEnum
) {
const role = delta.role ?? defaultRole;
const content = delta.content ?? "";
let additional_kwargs: Record<string, unknown>;
if (delta.function_call) {
additional_kwargs = {
function_call: delta.function_call,
};
} else if (delta.tool_calls) {
additional_kwargs = {
tool_calls: delta.tool_calls,
};
} else {
additional_kwargs = {};
}
if (this.__includeRawResponse) {
additional_kwargs.__raw_response = rawResponse;
}

if (delta.audio) {
additional_kwargs.audio = {
...delta.audio,
index: rawResponse.choices[0].index,
};
}

const response_metadata = { usage: { ...rawResponse.usage } };
if (role === "user") {
return new HumanMessageChunk({ content, response_metadata });
} else if (role === "assistant") {
const toolCallChunks: ToolCallChunk[] = [];
if (Array.isArray(delta.tool_calls)) {
for (const rawToolCall of delta.tool_calls) {
toolCallChunks.push({
name: rawToolCall.function?.name,
args: rawToolCall.function?.arguments,
id: rawToolCall.id,
index: rawToolCall.index,
type: "tool_call_chunk",
});
}
}
return new AIMessageChunk({
content,
tool_call_chunks: toolCallChunks,
additional_kwargs,
id: rawResponse.id,
response_metadata,
});
} else if (role === "system") {
return new SystemMessageChunk({ content, response_metadata });
} else if (role === "developer") {
return new SystemMessageChunk({
content,
response_metadata,
additional_kwargs: {
__openai_role__: "developer",
},
});
} else if (role === "function") {
return new FunctionMessageChunk({
content,
additional_kwargs,
name: delta.name,
response_metadata,
});
} else if (role === "tool") {
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: delta.tool_call_id,
response_metadata,
});
} else {
return new ChatMessageChunk({ content, role, response_metadata });
}
}

/** @ignore */
_identifyingParams(): Omit<
OpenAIClient.Chat.ChatCompletionCreateParams,
Expand Down Expand Up @@ -1335,11 +1336,10 @@ export class ChatOpenAI<
if (!delta) {
continue;
}
const chunk = _convertDeltaToMessageChunk(
const chunk = this._convertOpenAIDeltaToBaseMessageChunk(
delta,
data,
defaultRole,
this.__includeRawResponse
defaultRole
);
defaultRole = delta.role ?? defaultRole;
const newTokenIndices = {
Expand Down Expand Up @@ -1576,10 +1576,9 @@ export class ChatOpenAI<
const text = part.message?.content ?? "";
const generation: ChatGeneration = {
text,
message: openAIResponseToChatMessage(
message: this._convertOpenAIChatCompletionMessageToBaseMessage(
part.message ?? { role: "assistant" },
data,
this.__includeRawResponse
data
),
};
generation.generationInfo = {
Expand Down
58 changes: 58 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import { CallbackManager } from "@langchain/core/callbacks/manager";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { InMemoryCache } from "@langchain/core/caches";
import { concat } from "@langchain/core/utils/stream";
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
} from "openai/resources/index.mjs";
import { ChatOpenAI } from "../chat_models.js";

// Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
Expand Down Expand Up @@ -1227,3 +1232,56 @@ test("Allows developer messages with o1", async () => {
]);
expect(res.content).toEqual("testing");
});

test.skip("Allow overriding", async () => {
class ChatDeepSeek extends ChatOpenAI {
protected override _convertOpenAIDeltaToBaseMessageChunk(
delta: Record<string, any>,
rawResponse: ChatCompletionChunk,
defaultRole?:
| "function"
| "user"
| "system"
| "developer"
| "assistant"
| "tool"
) {
const messageChunk = super._convertOpenAIDeltaToBaseMessageChunk(
delta,
rawResponse,
defaultRole
);
messageChunk.additional_kwargs.reasoning_content =
delta.reasoning_content;
return messageChunk;
}

protected override _convertOpenAIChatCompletionMessageToBaseMessage(
message: ChatCompletionMessage,
rawResponse: ChatCompletion
) {
const langChainMessage =
super._convertOpenAIChatCompletionMessageToBaseMessage(
message,
rawResponse
);
langChainMessage.additional_kwargs.reasoning_content = (
message as any
).reasoning_content;
return langChainMessage;
}
}
const model = new ChatDeepSeek({
model: "deepseek-reasoner",
configuration: {
baseURL: "https://api.deepseek.com",
},
apiKey: process.env.DEEPSEEK_API_KEY,
});
const res = await model.invoke("what color is the sky?");
console.log(res);
const stream = await model.stream("what color is the sky?");
for await (const chunk of stream) {
console.log(chunk);
}
});

0 comments on commit cb335ba

Please sign in to comment.