Skip to content

Commit

Permalink
feat: add azure openai provider (#960)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahau-square authored Jan 31, 2025
1 parent 092d871 commit 5f6c85d
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 94 deletions.
8 changes: 7 additions & 1 deletion crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@
"description": "Lorem ipsum",
"models": [],
"required_keys": ["OPENROUTER_API_KEY"]
},
"azure_openai": {
"name": "Azure OpenAI",
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
}
}
}
141 changes: 141 additions & 0 deletions crates/goose/src/providers/azure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;

use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;

/// Default model name for the Azure OpenAI provider.
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
/// Azure OpenAI model catalog; surfaced as the provider's documentation link.
pub const AZURE_DOC_URL: &str =
    "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
/// Value of the `api-version` query parameter appended to every request URL.
pub const AZURE_API_VERSION: &str = "2024-10-21";
/// Models advertised in provider metadata; an Azure deployment may expose
/// others — TODO confirm whether unlisted deployment models are accepted.
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
    "gpt-4o",
    "gpt-4o-mini",
    "o1",
    "o1-mini",
    "o1-preview",
    "gpt-4",
];

/// Provider backed by the Azure OpenAI Service, speaking the
/// OpenAI-compatible chat-completions protocol.
#[derive(Debug, serde::Serialize)]
pub struct AzureProvider {
    // HTTP client is runtime state, not configuration — excluded from serialization.
    #[serde(skip)]
    client: Client,
    // Base resource URL; trailing slashes are tolerated (trimmed when building
    // request URLs). Presumably `https://<resource>.openai.azure.com` — TODO confirm.
    endpoint: String,
    // Secret credential, sent via the `api-key` request header.
    api_key: String,
    // Azure deployment name, interpolated into the request path.
    deployment_name: String,
    // Model configuration used to build chat requests.
    model: ModelConfig,
}

impl Default for AzureProvider {
    /// Builds a provider for the metadata default model from global config.
    ///
    /// Panics if any required configuration value (API key, endpoint,
    /// deployment name) is missing or the HTTP client cannot be built.
    fn default() -> Self {
        Self::from_env(ModelConfig::new(Self::metadata().default_model))
            .expect("Failed to initialize Azure OpenAI provider")
    }
}

impl AzureProvider {
    /// Constructs a provider by reading Azure OpenAI settings from the
    /// global configuration.
    ///
    /// # Errors
    /// Returns an error if `AZURE_OPENAI_API_KEY` (secret),
    /// `AZURE_OPENAI_ENDPOINT`, or `AZURE_OPENAI_DEPLOYMENT_NAME` is
    /// unavailable, or if the HTTP client fails to build.
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let cfg = crate::config::Config::global();
        let api_key: String = cfg.get_secret("AZURE_OPENAI_API_KEY")?;
        let endpoint: String = cfg.get("AZURE_OPENAI_ENDPOINT")?;
        let deployment_name: String = cfg.get("AZURE_OPENAI_DEPLOYMENT_NAME")?;

        // 600 s request timeout — completions on large prompts can run long.
        let http = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(Self {
            client: http,
            endpoint,
            api_key,
            deployment_name,
            model,
        })
    }

    /// POSTs `payload` to this deployment's chat-completions route and
    /// returns the parsed JSON body via the shared OpenAI-compat handler.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        // Tolerate a trailing slash on the configured endpoint.
        let base = self.endpoint.trim_end_matches('/');
        let url = format!(
            "{base}/openai/deployments/{}/chat/completions?api-version={AZURE_API_VERSION}",
            self.deployment_name
        );

        // Azure authenticates with an `api-key` header rather than a Bearer token.
        let response = self
            .client
            .post(&url)
            .header("api-key", &self.api_key)
            .json(&payload)
            .send()
            .await?;

        handle_response_openai_compat(response).await
    }
}

#[async_trait]
impl Provider for AzureProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"azure_openai",
"Azure OpenAI",
"Models through Azure OpenAI Service",
"gpt-4o",
AZURE_OPENAI_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect(),
AZURE_DOC_URL,
vec![
ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None),
ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
ConfigKey::new(
"AZURE_OPENAI_DEPLOYMENT_NAME",
true,
false,
Some("Name of your Azure OpenAI deployment"),
),
],
)
}

fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}

#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
let response = self.post(payload.clone()).await?;

let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::warn!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
emit_debug_trace(self, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
}
3 changes: 3 additions & 0 deletions crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
anthropic::AnthropicProvider,
azure::AzureProvider,
base::{Provider, ProviderMetadata},
databricks::DatabricksProvider,
google::GoogleProvider,
Expand All @@ -14,6 +15,7 @@ use anyhow::Result;
pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
AzureProvider::metadata(),
DatabricksProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
Expand All @@ -27,6 +29,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
match name {
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod anthropic;
pub mod azure;
pub mod base;
pub mod databricks;
pub mod errors;
Expand Down
45 changes: 26 additions & 19 deletions ui/desktop/src/components/settings/ProviderSetupModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ interface ProviderSetupModalProps {
model: string;
endpoint: string;
title?: string;
onSubmit: (apiKey: string) => void;
onSubmit: (configValues: { [key: string]: string }) => void;
onCancel: () => void;
}

Expand All @@ -24,14 +24,14 @@ export function ProviderSetupModal({
onSubmit,
onCancel,
}: ProviderSetupModalProps) {
const [apiKey, setApiKey] = React.useState('');
const keyName = required_keys[provider]?.[0] || 'API Key';
const headerText = `Setup ${provider}`;
const [configValues, setConfigValues] = React.useState<{ [key: string]: string }>({});
const requiredKeys = required_keys[provider] || ['API Key'];
const headerText = title || `Setup ${provider}`;

const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
onSubmit(apiKey);
onSubmit(configValues);
};
const inputType = isSecretKey(keyName) ? 'password' : 'text';

return (
<div className="fixed inset-0 bg-black/20 dark:bg-white/20 backdrop-blur-sm transition-colors animate-[fadein_200ms_ease-in_forwards]">
Expand All @@ -48,20 +48,27 @@ export function ProviderSetupModal({

{/* Form */}
<form onSubmit={handleSubmit}>
<div className="mt-[24px]">
<div>
<Input
type={inputType}
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
placeholder={keyName}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required
/>
<div className="flex mt-4 text-gray-600 dark:text-gray-300">
<Lock className="w-6 h-6" />
<span className="text-sm font-light ml-4 mt-[2px]">{`Your API key or host will be stored securely in the keychain and used only for making requests to ${provider}`}</span>
<div className="mt-[24px] space-y-4">
{requiredKeys.map((keyName) => (
<div key={keyName}>
<Input
type={isSecretKey(keyName) ? 'password' : 'text'}
value={configValues[keyName] || ''}
onChange={(e) =>
setConfigValues((prev) => ({
...prev,
[keyName]: e.target.value,
}))
}
placeholder={keyName}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required
/>
</div>
))}
<div className="flex text-gray-600 dark:text-gray-300">
<Lock className="w-6 h-6" />
<span className="text-sm font-light ml-4 mt-[2px]">{`Your configuration values will be stored securely in the keychain and used only for making requests to ${provider}`}</span>
</div>
</div>

Expand Down
15 changes: 10 additions & 5 deletions ui/desktop/src/components/settings/api_keys/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ import { Provider, ProviderResponse } from './types';
import { getApiUrl, getSecretKey } from '../../../config';

export function isSecretKey(keyName: string): boolean {
  // Endpoints, hosts, and deployment names are plain configuration rather
  // than credentials, so they must not be stored as secrets.
  const nonSecretKeys = [
    'DATABRICKS_HOST',
    'OLLAMA_HOST',
    'AZURE_OPENAI_ENDPOINT',
    'AZURE_OPENAI_DEPLOYMENT_NAME',
  ];
  return !nonSecretKeys.includes(keyName);
}

export async function getActiveProviders(): Promise<string[]> {
Expand All @@ -16,9 +22,8 @@ export async function getActiveProviders(): Promise<string[]> {
.filter((provider) => {
const apiKeyStatus = Object.values(provider.config_status || {}); // Get all key statuses

// Include providers if:
// - They have at least one key set (`is_set: true`)
return apiKeyStatus.some((key) => key.is_set);
// Include providers if all required keys are set
return apiKeyStatus.length > 0 && apiKeyStatus.every((key) => key.is_set);
})
.map((provider) => provider.name || 'Unknown Provider'); // Extract provider name

Expand Down
7 changes: 7 additions & 0 deletions ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const goose_models: Model[] = [
{ id: 15, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 16, name: 'qwen2.5', provider: 'Ollama' },
{ id: 17, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 18, name: 'gpt-4o', provider: 'Azure OpenAI' },
];

export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
Expand All @@ -42,6 +43,8 @@ export const ollama_mdoels = ['qwen2.5'];

export const openrouter_models = ['anthropic/claude-3.5-sonnet'];

export const azure_openai_models = ['gpt-4o'];

export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
Expand All @@ -50,6 +53,7 @@ export const default_models = {
groq: 'llama-3.3-70b-versatile',
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
};

export function getDefaultModel(key: string): string | undefined {
Expand All @@ -66,6 +70,7 @@ export const required_keys = {
Ollama: ['OLLAMA_HOST'],
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
};

export const supported_providers = [
Expand All @@ -76,6 +81,7 @@ export const supported_providers = [
'Google',
'Ollama',
'OpenRouter',
'Azure OpenAI',
];

export const model_docs_link = [
Expand All @@ -99,4 +105,5 @@ export const provider_aliases = [
{ provider: 'Databricks', alias: 'databricks' },
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
];
Loading

0 comments on commit 5f6c85d

Please sign in to comment.