diff --git a/lua/gp/config.lua b/lua/gp/config.lua
index 768095c6..ff2b535c 100644
--- a/lua/gp/config.lua
+++ b/lua/gp/config.lua
@@ -49,7 +49,7 @@ local config = {
 			endpoint = "http://localhost:1234/v1/chat/completions",
 		},
 		googleai = {
-			endpoint = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key={{secret}}",
+			endpoint = "https://generativelanguage.googleapis.com/v1beta/models/{{model}}:streamGenerateContent?key={{secret}}",
 			secret = os.getenv("GOOGLEAI_API_KEY"),
 		},
 	},
@@ -128,7 +128,7 @@ local config = {
 			chat = true,
 			command = false,
 			-- string with model name or table with model name and parameters
-			model = { model = "dummy", temperature = 1.1, top_p = 1 },
+			model = { model = "gemini-pro", temperature = 1.1, top_p = 1 },
 			-- system prompt (use this to specify the persona/role of the AI)
 			system_prompt = "You are a general AI assistant.\n\n"
 				.. "The user provided the additional info about how they would like you to respond:\n\n"
diff --git a/lua/gp/init.lua b/lua/gp/init.lua
index 28d8c585..d1e71911 100644
--- a/lua/gp/init.lua
+++ b/lua/gp/init.lua
@@ -1215,6 +1215,7 @@ M.prepare_payload = function(messages, model, default_model, provider)
 			topP = math.max(0, math.min(1, model.top_p or 1)),
 			topK = model.top_k or 100,
 		},
+		model = model.model,
 	}
 	return payload
 end
@@ -1407,6 +1408,8 @@ M.query = function(buf, provider, payload, handler, on_exit)
 	if provider == "googleai" then
 		headers = {}
 		endpoint = M._H.template_replace(endpoint, "{{secret}}", bearer)
+		endpoint = M._H.template_replace(endpoint, "{{model}}", payload.model)
+		payload.model = nil
 	end

 	if provider == "azure" then
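
For context, a minimal sketch of how the templated googleai endpoint gets resolved per request, assuming M._H.template_replace does a plain placeholder substitution (its implementation is not part of this diff); the standalone helper, model name, and key handling below are illustrative only:

-- Illustrative stand-in for M._H.template_replace (assumed behavior).
local function template_replace(template, key, value)
	-- escape "%" so gsub treats the value as literal text
	value = tostring(value):gsub("%%", "%%%%")
	return (template:gsub(key, value))
end

local endpoint = "https://generativelanguage.googleapis.com/v1beta/models/{{model}}:streamGenerateContent?key={{secret}}"
local payload = { model = "gemini-pro", contents = {} }

endpoint = template_replace(endpoint, "{{secret}}", os.getenv("GOOGLEAI_API_KEY") or "")
endpoint = template_replace(endpoint, "{{model}}", payload.model)
payload.model = nil -- model is consumed by the URL, not sent in the request body

print(endpoint)
-- https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=...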