diff --git a/README.md b/README.md
index fe91dbc7..1dddc4b0 100644
--- a/README.md
+++ b/README.md
@@ -68,12 +68,17 @@ Get up and running with WhoDB quickly using Docker:
 ```sh
 docker run -it -p 8080:8080 clidey/whodb
 ```
 
+To run WhoDB with an OpenAI-compatible service, set the following environment variables:
+```sh
+docker run -it -e USE_CUSTOM_MODELS=1 -e CUSTOM_MODELS=gpt-4o,gpt-3.5,others -e OPENAI_BASE_URL=http://your_base_url/v1 -p 8080:8080 clidey/whodb
+```
+
 If you are using a remote Ollama server, please start Docker with Ollama Environments like this:
 ```sh
 docker run -it -e WHODB_OLLAMA_HOST=YOUR_OLLAMA_HOST -e WHODB_OLLAMA_PORT=YOUR_OLLAMA_PORT -p 8080:8080 clidey/whodb
 ```
 
 Or, use Docker Compose:
 ```sh
 version: "3.8"
diff --git a/core/src/llm/chatgpt_client.go b/core/src/llm/chatgpt_client.go
index 6996804f..77fdab31 100644
--- a/core/src/llm/chatgpt_client.go
+++ b/core/src/llm/chatgpt_client.go
@@ -9,7 +9,5 @@ import (
 	"strings"
 )
 
-const chatGPTEndpoint = "https://api.openai.com/v1"
-
 func prepareChatGPTRequest(c *LLMClient, prompt string, model LLMModel, receiverChan *chan string) (string, []byte, map[string]string, error) {
 	requestBody, err := json.Marshal(map[string]interface{}{
@@ -20,7 +18,7 @@ func prepareChatGPTRequest(c *LLMClient, prompt string, model LLMModel, receiver
 	if err != nil {
 		return "", nil, nil, err
 	}
-	url := fmt.Sprintf("%v/chat/completions", chatGPTEndpoint)
+	url := fmt.Sprintf("%v/chat/completions", getOpenAICompatibleBaseURL())
 	headers := map[string]string{
 		"Authorization": fmt.Sprintf("Bearer %s", c.APIKey),
 		"Content-Type":  "application/json",
@@ -29,7 +27,7 @@ func prepareChatGPTRequest(c *LLMClient, prompt string, model LLMModel, receiver
 }
 
 func prepareChatGPTModelsRequest(apiKey string) (string, map[string]string) {
-	url := fmt.Sprintf("%v/models", chatGPTEndpoint)
+	url := fmt.Sprintf("%v/models", getOpenAICompatibleBaseURL())
 	headers := map[string]string{
 		"Authorization": fmt.Sprintf("Bearer %s", apiKey),
 		"Content-Type":  "application/json",
diff --git a/core/src/llm/env.go b/core/src/llm/env.go
index 6718d2a7..2bb7bb6f 100644
--- a/core/src/llm/env.go
+++ b/core/src/llm/env.go
@@ -2,7 +2,9 @@ package llm
 
 import (
 	"fmt"
+	"os"
+	"strings"
 
 	"github.com/clidey/whodb/core/src/common"
 	"github.com/clidey/whodb/core/src/env"
 )
@@ -25,2 +27,28 @@ func getOllamaEndpoint() string {
 	return fmt.Sprintf("http://%v:%v/api", host, port)
 }
+
+func getOpenAICompatibleBaseURL() string {
+	defaultBaseURL := "https://api.openai.com/v1"
+	baseURL := os.Getenv("OPENAI_BASE_URL")
+	if baseURL == "" {
+		baseURL = defaultBaseURL
+	}
+	return baseURL
+}
+
+func getCustomModels() ([]string, error) {
+	modelsStr := os.Getenv("CUSTOM_MODELS")
+	if modelsStr == "" {
+		return []string{}, nil
+	}
+
+	models := strings.Split(modelsStr, ",")
+	for i := range models {
+		models[i] = strings.TrimSpace(models[i])
+	}
+	return models, nil
+}
+
+func ShouldUseCustomModels() bool {
+	return os.Getenv("USE_CUSTOM_MODELS") == "1"
+}
diff --git a/core/src/llm/llm_client.go b/core/src/llm/llm_client.go
index 709a04ec..557cab27 100644
--- a/core/src/llm/llm_client.go
+++ b/core/src/llm/llm_client.go
@@ -66,6 +66,9 @@ func (c *LLMClient) GetSupportedModels() ([]string, error) {
 		url, headers = prepareOllamaModelsRequest()
 	case ChatGPT_LLMType:
+		if ShouldUseCustomModels() {
+			return getCustomModels()
+		}
 		url, headers = prepareChatGPTModelsRequest(c.APIKey)
 	case Anthropic_LLMType:
 		return getAnthropicModels(c.APIKey)
 	default:
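
Reviewer note, not part of the patch: below is a minimal, hypothetical test sketch (the file and test names are mine, not the PR author's) showing how the new environment-driven helpers in `core/src/llm/env.go` behave. It assumes Go 1.17+ for `t.Setenv` and would live inside the `llm` package, since the helpers other than `ShouldUseCustomModels` are unexported.

```go
package llm

import (
	"reflect"
	"testing"
)

// TestOpenAICompatibleEnv is a hypothetical check (not included in this PR)
// of the helpers added in env.go.
func TestOpenAICompatibleEnv(t *testing.T) {
	// With OPENAI_BASE_URL unset (empty), the client falls back to the public
	// OpenAI endpoint, preserving the old chatGPTEndpoint behavior.
	t.Setenv("OPENAI_BASE_URL", "")
	if got := getOpenAICompatibleBaseURL(); got != "https://api.openai.com/v1" {
		t.Errorf("base URL = %q, want default", got)
	}

	// A custom base URL is used verbatim; callers append /chat/completions
	// or /models to it.
	t.Setenv("OPENAI_BASE_URL", "http://localhost:8000/v1")
	if got := getOpenAICompatibleBaseURL(); got != "http://localhost:8000/v1" {
		t.Errorf("base URL = %q", got)
	}

	// CUSTOM_MODELS is a comma-separated list; surrounding whitespace is trimmed.
	t.Setenv("CUSTOM_MODELS", " gpt-4o , my-local-model")
	models, err := getCustomModels()
	if err != nil {
		t.Fatal(err)
	}
	if want := []string{"gpt-4o", "my-local-model"}; !reflect.DeepEqual(models, want) {
		t.Errorf("models = %v, want %v", models, want)
	}

	// USE_CUSTOM_MODELS must be exactly "1" to switch GetSupportedModels over
	// to the static list instead of calling GET {base}/models.
	t.Setenv("USE_CUSTOM_MODELS", "1")
	if !ShouldUseCustomModels() {
		t.Error("ShouldUseCustomModels() = false, want true")
	}
}
```

One design note on the `llm_client.go` hunk: checking `ShouldUseCustomModels()` before building the `/models` request means a deployment with a static model list never touches the network-dependent endpoint, so an unreachable or misconfigured `OPENAI_BASE_URL` cannot break model listing.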