fix: gemini&claude tool call format #795 #766
Calcium-Ion committed Feb 26, 2025
1 parent 6d8d40e commit 13ab0f8
Showing 6 changed files with 85 additions and 89 deletions.
80 changes: 41 additions & 39 deletions dto/openai_request.go
@@ -18,50 +18,52 @@ type FormatJsonSchema struct {
}

type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Prefix any `json:"prefix,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- Stream bool `json:"stream,omitempty"`
- StreamOptions *StreamOptions `json:"stream_options,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- EncodingFormat any `json:"encoding_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Tools []ToolCall `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
- LogProbs bool `json:"logprobs,omitempty"`
- TopLogProbs int `json:"top_logprobs,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- Modalities any `json:"modalities,omitempty"`
- Audio any `json:"audio,omitempty"`
- ExtraBody any `json:"extra_body,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions any `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ EncodingFormat any `json:"encoding_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Tools []ToolCallRequest `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+ LogProbs bool `json:"logprobs,omitempty"`
+ TopLogProbs int `json:"top_logprobs,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Modalities any `json:"modalities,omitempty"`
+ Audio any `json:"audio,omitempty"`
+ ExtraBody any `json:"extra_body,omitempty"`
}

- type OpenAITools struct {
- Type string `json:"type"`
- Function OpenAIFunction `json:"function"`
+ type ToolCallRequest struct {
+ ID string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Function FunctionRequest `json:"function"`
}

- type OpenAIFunction struct {
+ type FunctionRequest struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
+ Arguments string `json:"arguments,omitempty"`
}

type StreamOptions struct {
@@ -137,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}

- func (m *Message) ParseToolCalls() []ToolCall {
+ func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil {
return nil
}
- var toolCalls []ToolCall
+ var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
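The net effect in this file: request-side tool definitions get their own pair of types, ToolCallRequest and FunctionRequest, instead of sharing ToolCall/FunctionCall with responses. Below is a minimal sketch (not part of the commit) of how the new request-side types compose, assuming the dto package above; the model and tool names are illustrative only. Note that on the request side the function carries a "parameters" JSON schema, while "arguments" stays empty.

// Sketch (not part of the commit): building a request with the renamed
// request-side types. Assumes the one-api/dto package shown above.
package main

import (
	"encoding/json"
	"fmt"

	"one-api/dto"
)

func main() {
	req := dto.GeneralOpenAIRequest{
		Model: "gpt-4o", // illustrative model name
		Tools: []dto.ToolCallRequest{{
			Type: "function",
			Function: dto.FunctionRequest{
				Name:        "get_weather",
				Description: "Look up the current weather for a city",
				// Request side: a JSON-schema object under "parameters".
				Parameters: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"city": map[string]any{"type": "string"},
					},
					"required": []string{"city"},
				},
			},
		}},
	}
	b, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(b))
}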
24 changes: 12 additions & 12 deletions dto/openai_response.go
@@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}

type ChatCompletionsStreamResponseChoiceDelta struct {
- Content *string `json:"content,omitempty"`
- ReasoningContent *string `json:"reasoning_content,omitempty"`
- Role string `json:"role,omitempty"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ Content *string `json:"content,omitempty"`
+ ReasoningContent *string `json:"reasoning_content,omitempty"`
+ Role string `json:"role,omitempty"`
+ ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
}

func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
c.ReasoningContent = &s
}

- type ToolCall struct {
+ type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object
- Index *int `json:"index,omitempty"`
- ID string `json:"id,omitempty"`
- Type any `json:"type"`
- Function FunctionCall `json:"function"`
+ Index *int `json:"index,omitempty"`
+ ID string `json:"id,omitempty"`
+ Type any `json:"type"`
+ Function FunctionResponse `json:"function"`
}

- func (c *ToolCall) SetIndex(i int) {
+ func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i
}

- type FunctionCall struct {
+ type FunctionResponse struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
- Arguments string `json:"arguments"`
+ Arguments string `json:"arguments"` // response
}

type ChatCompletionsStreamResponse struct {
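On the response side, ToolCallResponse/FunctionResponse carry what the model emits: Arguments is a raw JSON string that arrives in fragments across stream chunks (keyed by Index in chunk objects), whereas the request-side Parameters holds a schema object. A minimal sketch (not part of the commit), assuming the dto package above; the id and argument values are illustrative.

// Sketch (not part of the commit): how the response-side types carry a
// streamed tool call and why Arguments is a plain string.
package main

import (
	"fmt"

	"one-api/dto"
)

func main() {
	first := dto.ToolCallResponse{
		ID:       "call_abc123", // illustrative id
		Type:     "function",
		Function: dto.FunctionResponse{Name: "get_weather", Arguments: `{"ci`},
	}
	first.SetIndex(0)

	second := dto.ToolCallResponse{
		Function: dto.FunctionResponse{Arguments: `ty":"Paris"}`},
	}
	second.SetIndex(0)

	// A client keys fragments by Index and concatenates Arguments.
	full := first.Function.Arguments + second.Function.Arguments
	fmt.Println(full) // {"city":"Paris"}
}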
16 changes: 8 additions & 8 deletions relay/channel/claude/relay-claude.go
@@ -296,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
- tools := make([]dto.ToolCall, 0)
+ tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -315,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
- tools = append(tools, dto.ToolCall{
+ tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
- Function: dto.FunctionCall{
+ Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
@@ -333,8 +333,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString(claudeResponse.Delta.Text)
switch claudeResponse.Delta.Type {
case "input_json_delta":
- tools = append(tools, dto.ToolCall{
- Function: dto.FunctionCall{
+ tools = append(tools, dto.ToolCallResponse{
+ Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
},
})
@@ -382,7 +382,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
- tools := make([]dto.ToolCall, 0)
+ tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""

if reqMode == RequestModeCompletion {
@@ -403,10 +403,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
- tools = append(tools, dto.ToolCall{
+ tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
- Function: dto.FunctionCall{
+ Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
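The Claude conversion above opens a tool call when a content block of type "tool_use" starts (id and name, with empty arguments) and then appends each "input_json_delta" fragment to the arguments. A standalone sketch of that mapping (not part of the commit); the event structs here are simplified stand-ins for the repo's ClaudeResponse, kept only to show the shape.

// Sketch (not part of the commit): the Anthropic-event to OpenAI-delta
// mapping performed above, reduced to a standalone form.
package main

import (
	"fmt"

	"one-api/dto"
)

// contentBlock mirrors the fields used from claudeResponse.ContentBlock.
type contentBlock struct {
	Type, Id, Name string
}

// delta mirrors the fields used from claudeResponse.Delta.
type delta struct {
	Type, PartialJson string
}

// openToolCall corresponds to the "tool_use" content_block branch above.
func openToolCall(b contentBlock) *dto.ToolCallResponse {
	if b.Type != "tool_use" {
		return nil
	}
	return &dto.ToolCallResponse{
		ID:       b.Id,
		Type:     "function",
		Function: dto.FunctionResponse{Name: b.Name, Arguments: ""},
	}
}

// appendArguments corresponds to the "input_json_delta" branch above.
func appendArguments(d delta) *dto.ToolCallResponse {
	if d.Type != "input_json_delta" {
		return nil
	}
	return &dto.ToolCallResponse{
		Function: dto.FunctionResponse{Arguments: d.PartialJson},
	}
}

func main() {
	open := openToolCall(contentBlock{Type: "tool_use", Id: "toolu_1", Name: "get_weather"})
	frag := appendArguments(delta{Type: "input_json_delta", PartialJson: `{"city":"Paris"}`})
	fmt.Println(open.ID, open.Function.Name, frag.Function.Arguments)
}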
14 changes: 7 additions & 7 deletions relay/channel/gemini/relay-gemini.go
@@ -43,7 +43,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque

// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
- functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
+ functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
@@ -338,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}

- func getToolCall(item *GeminiPart) *dto.ToolCall {
+ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -350,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil {
return nil
}
- return &dto.ToolCall{
+ return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
- Function: dto.FunctionCall{
+ Function: dto.FunctionResponse{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
@@ -380,11 +380,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
if len(candidate.Content.Parts) > 0 {
var texts []string
- var toolCalls []dto.ToolCall
+ var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
- if call := getToolCall(&part); call != nil {
+ if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else {
@@ -457,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
- if call := getToolCall(&part); call != nil {
+ if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
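For Gemini, the renamed getResponseToolCall marshals the structured function-call arguments into the Arguments string and fabricates an OpenAI-style "call_<uuid>" id, since Gemini does not supply one. A standalone sketch of that conversion (not part of the commit); geminiFunctionCall and newUUID are simplified stand-ins for the repo's GeminiPart.FunctionCall and common.GetUUID.

// Sketch (not part of the commit): the conversion getResponseToolCall
// performs, in standalone form.
package main

import (
	"encoding/json"
	"fmt"

	"one-api/dto"
)

type geminiFunctionCall struct {
	FunctionName string
	Arguments    any // Gemini returns a structured JSON object
}

func newUUID() string { return "0123456789abcdef" } // stand-in for common.GetUUID

func toResponseToolCall(fc geminiFunctionCall) *dto.ToolCallResponse {
	argsBytes, err := json.Marshal(fc.Arguments)
	if err != nil {
		return nil
	}
	return &dto.ToolCallResponse{
		// Gemini supplies no call id, so one is fabricated OpenAI-style.
		ID:   fmt.Sprintf("call_%s", newUUID()),
		Type: "function",
		Function: dto.FunctionResponse{
			Name:      fc.FunctionName,
			Arguments: string(argsBytes), // object flattened to a JSON string
		},
	}
}

func main() {
	call := toResponseToolCall(geminiFunctionCall{
		FunctionName: "get_weather",
		Arguments:    map[string]any{"city": "Paris"},
	})
	fmt.Println(call.ID, call.Function.Name, call.Function.Arguments)
}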
32 changes: 16 additions & 16 deletions relay/channel/ollama/dto.go
@@ -3,22 +3,22 @@ package ollama
import "one-api/dto"

type OllamaRequest struct {
- Model string `json:"model,omitempty"`
- Messages []dto.Message `json:"messages,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Topp float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- Tools []dto.ToolCall `json:"tools,omitempty"`
- ResponseFormat any `json:"response_format,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
- Prompt any `json:"prompt,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages []dto.Message `json:"messages,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Topp float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ Tools []dto.ToolCallRequest `json:"tools,omitempty"`
+ ResponseFormat any `json:"response_format,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
}

type Options struct {
8 changes: 1 addition & 7 deletions service/token_counter.go
@@ -1,7 +1,6 @@
package service

import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -170,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
tkm += msgTokens
if request.Tools != nil {
- toolsData, _ := json.Marshal(request.Tools)
- var openaiTools []dto.OpenAITools
- err := json.Unmarshal(toolsData, &openaiTools)
- if err != nil {
- return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
- }
+ openaiTools := request.Tools
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
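Because request.Tools is now already []dto.ToolCallRequest, the counter no longer round-trips the tools through json.Marshal/json.Unmarshal into the removed OpenAITools type, which is why the encoding/json import is dropped and the error path disappears. A sketch of the resulting counting shape (not part of the commit); countTokens and the string-building details are hypothetical stand-ins for the repo's tokenizer logic.

// Sketch (not part of the commit): tool-token counting over the typed
// request.Tools slice, with no intermediate JSON round-trip.
package main

import (
	"fmt"

	"one-api/dto"
)

func countTokens(s string) int { return len(s) / 4 } // crude stand-in

func countToolTokens(tools []dto.ToolCallRequest) int {
	total := 0
	for _, tool := range tools {
		countStr := tool.Function.Name
		if tool.Function.Description != "" {
			countStr += tool.Function.Description
		}
		if tool.Function.Parameters != nil {
			countStr += fmt.Sprintf("%v", tool.Function.Parameters)
		}
		total += countTokens(countStr)
	}
	return total
}

func main() {
	tools := []dto.ToolCallRequest{{
		Type:     "function",
		Function: dto.FunctionRequest{Name: "get_weather", Description: "weather lookup"},
	}}
	fmt.Println(countToolTokens(tools))
}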
