Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make anthropic work #8

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/models/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,35 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
}

/// Content of a chat message: either a bare string or an array of typed
/// parts (OpenAI-compatible). `#[serde(untagged)]` lets serde accept and
/// emit both JSON shapes without a discriminator field.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum ChatMessageContent {
    /// Plain-text content: `"content": "hello"`.
    String(String),
    /// Structured content: `"content": [{"type": "text", "text": "hello"}]`.
    Array(Vec<ChatMessageContentPart>),
}

/// One structured part of a chat message (e.g. a text fragment).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatMessageContentPart {
    /// Part kind; serialized under the JSON key "type".
    #[serde(rename = "type")]
    pub r#type: String,
    /// Text payload of this part.
    pub text: String,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
pub role: String,
pub content: String,
pub content: ChatMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub object: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<u64>,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
Expand Down
185 changes: 67 additions & 118 deletions src/providers/anthropic.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use axum::async_trait;
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};

use super::provider::Provider;
use crate::config::models::{ModelConfig, Provider as ProviderConfig};
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::common::Usage;
use crate::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse};
use crate::models::embeddings::{
Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
use crate::models::chat::{
ChatCompletionChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ChatMessageContentPart,
};
use crate::models::common::Usage;
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use reqwest::Client;

pub struct AnthropicProvider {
Expand All @@ -17,6 +19,27 @@ pub struct AnthropicProvider {
http_client: Client,
}

/// One content block from an Anthropic Messages API response.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct AnthropicContent {
    /// Text payload of the content block.
    pub text: String,
    /// Block kind as reported by the API; serialized as the JSON key "type".
    #[serde(rename = "type")]
    pub r#type: String,
}

/// Minimal subset of Anthropic's Messages API response body that this
/// provider needs to build an OpenAI-style `ChatCompletionResponse`.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct AnthropicChatCompletionResponse {
    /// Response identifier assigned by Anthropic.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Assistant output as a list of content blocks.
    pub content: Vec<AnthropicContent>,
    /// Token accounting for the request/response pair.
    pub usage: AnthropicUsage,
}

/// Token usage as reported by Anthropic (split into input and output).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct AnthropicUsage {
    /// Tokens consumed by the prompt.
    pub input_tokens: u32,
    /// Tokens produced in the completion.
    pub output_tokens: u32,
}

#[async_trait]
impl Provider for AnthropicProvider {
fn new(config: &ProviderConfig) -> Self {
Expand All @@ -43,7 +66,7 @@ impl Provider for AnthropicProvider {
let response = self
.http_client
.post("https://api.anthropic.com/v1/messages")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&payload)
.send()
Expand All @@ -52,134 +75,60 @@ impl Provider for AnthropicProvider {

let status = response.status();
if status.is_success() {
response
let anthropic_response: AnthropicChatCompletionResponse = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.expect("Failed to parse Anthropic response");

Ok(ChatCompletionResponse {
id: anthropic_response.id,
object: None,
created: None,
model: anthropic_response.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatCompletionMessage {
name: None,
role: "assistant".to_string(),
content: crate::models::chat::ChatMessageContent::Array(
anthropic_response
.content
.into_iter()
.map(|content| ChatMessageContentPart {
r#type: content.r#type,
text: content.text,
})
.collect(),
),
},
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
usage: Usage {
prompt_tokens: anthropic_response.usage.input_tokens,
completion_tokens: anthropic_response.usage.output_tokens,
total_tokens: anthropic_response.usage.input_tokens
+ anthropic_response.usage.output_tokens,
},
})
} else {
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}

async fn completions(
&self,
payload: CompletionRequest,
_payload: CompletionRequest,
_model_config: &ModelConfig,
) -> Result<CompletionResponse, StatusCode> {
let anthropic_payload = serde_json::json!({
"model": payload.model,
"prompt": format!("\n\nHuman: {}\n\nAssistant:", payload.prompt),
"max_tokens_to_sample": payload.max_tokens.unwrap_or(100),
"temperature": payload.temperature.unwrap_or(0.7),
"top_p": payload.top_p.unwrap_or(1.0),
"stop_sequences": payload.stop.unwrap_or_default(),
});

let response = self
.http_client
.post("https://api.anthropic.com/v1/complete")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("anthropic-version", "2023-06-01")
.json(&anthropic_payload)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let status = response.status();
if !status.is_success() {
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(CompletionResponse {
id: anthropic_response["completion_id"]
.as_str()
.unwrap_or("")
.to_string(),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: payload.model,
choices: vec![CompletionChoice {
text: anthropic_response["completion"]
.as_str()
.unwrap_or("")
.to_string(),
index: 0,
logprobs: None,
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
unimplemented!()
}

async fn embeddings(
&self,
payload: EmbeddingsRequest,
_payload: EmbeddingsRequest,
_model_config: &ModelConfig,
) -> Result<EmbeddingsResponse, StatusCode> {
let anthropic_payload = match &payload.input {
EmbeddingsInput::Single(text) => serde_json::json!({
"model": payload.model,
"text": text,
}),
EmbeddingsInput::Multiple(texts) => serde_json::json!({
"model": payload.model,
"text": texts,
}),
};

let response = self
.http_client
.post("https://api.anthropic.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("anthropic-version", "2023-06-01")
.json(&anthropic_payload)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let status = response.status();
if !status.is_success() {
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
.json()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let embedding = anthropic_response["embedding"]
.as_array()
.unwrap_or(&Vec::new())
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();

Ok(EmbeddingsResponse {
object: "list".to_string(),
model: payload.model,
data: vec![Embeddings {
object: "embedding".to_string(),
embedding,
index: 0,
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
unimplemented!()
}
}