From 424b005b4f982279a5fdc98e366a0fcdc39ee7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Caddet?= Date: Sun, 25 Aug 2024 18:28:48 +0200 Subject: [PATCH] feat(Agent)!: Add prefix message feature Now you can set a prefix to control the assistant output start. It is the oportunity to improve the `AgentMessage` type. It's not a struct anymore. It became an enum having a variant per possible roles. --- examples/agent.rs | 11 ++--- examples/agent_async.rs | 11 ++--- examples/agent_with_function_calling.rs | 10 ++-- examples/agent_with_function_calling_async.rs | 10 ++-- src/v1/agent.rs | 47 +++++++++++-------- 5 files changed, 45 insertions(+), 44 deletions(-) diff --git a/examples/agent.rs b/examples/agent.rs index f7e9489..dfc8285 100644 --- a/examples/agent.rs +++ b/examples/agent.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - agent::{AgentMessage, AgentMessageRole, AgentParams}, + agent::{AgentMessage, AgentParams}, client::Client, }; @@ -9,11 +9,10 @@ fn main() { let agid = std::env::var("MISTRAL_API_AGENT") .expect("Please export MISTRAL_API_AGENT with the agent id you want to use"); - let messages = vec![AgentMessage { - role: AgentMessageRole::User, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - tool_calls: None, - }]; + let messages = vec![ + AgentMessage::new_user_message("What's the best city in the world?"), + AgentMessage::new_prefix("Valpo "), + ]; let options = AgentParams { random_seed: Some(42), ..Default::default() diff --git a/examples/agent_async.rs b/examples/agent_async.rs index b0c2d90..28e44cf 100644 --- a/examples/agent_async.rs +++ b/examples/agent_async.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - agent::{AgentMessage, AgentMessageRole, AgentParams}, + agent::{AgentMessage, AgentParams}, client::Client, }; @@ -10,11 +10,10 @@ async fn main() { let agid = std::env::var("MISTRAL_API_AGENT") .expect("Please export MISTRAL_API_AGENT with the agent id you want to use"); - let messages = vec![AgentMessage { - role: AgentMessageRole::User, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - tool_calls: None, - }]; + let messages = vec![ + AgentMessage::new_user_message("What's the best city in the world?"), + AgentMessage::new_prefix("Valpo "), + ]; let options = AgentParams { random_seed: Some(42), ..Default::default() diff --git a/examples/agent_with_function_calling.rs b/examples/agent_with_function_calling.rs index bd85241..fbeb022 100644 --- a/examples/agent_with_function_calling.rs +++ b/examples/agent_with_function_calling.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - agent::{AgentMessage, AgentMessageRole, AgentParams}, + agent::{AgentMessage, AgentParams}, client::Client, tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, }; @@ -47,11 +47,9 @@ fn main() { let agid = std::env::var("MISTRAL_API_AGENT") .expect("Please export MISTRAL_API_AGENT with the agent id you want to use"); - let messages = vec![AgentMessage { - role: AgentMessageRole::User, - content: "What's the temperature in Paris?".to_string(), - tool_calls: None, - }]; + let messages = vec![AgentMessage::new_user_message( + "What's the temperature in Paris?", + )]; let options = AgentParams { random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), diff --git a/examples/agent_with_function_calling_async.rs b/examples/agent_with_function_calling_async.rs index 58feb02..a7d4969 100644 --- a/examples/agent_with_function_calling_async.rs +++ b/examples/agent_with_function_calling_async.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - agent::{AgentMessage, AgentMessageRole, AgentParams}, + agent::{AgentMessage, AgentParams}, client::Client, tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, }; @@ -48,11 +48,9 @@ async fn main() { let agid = std::env::var("MISTRAL_API_AGENT") .expect("Please export MISTRAL_API_AGENT with the agent id you want to use"); - let messages = vec![AgentMessage { - role: AgentMessageRole::User, - content: "What's the temperature in Paris?".to_string(), - tool_calls: None, - }]; + let messages = vec![AgentMessage::new_user_message( + "What's the temperature in Paris?", + )]; let options = AgentParams { random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), diff --git a/src/v1/agent.rs b/src/v1/agent.rs index 8b94ed7..afedc63 100644 --- a/src/v1/agent.rs +++ b/src/v1/agent.rs @@ -9,41 +9,48 @@ use crate::v1::{ // Definitions #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct AgentMessage { - pub role: AgentMessageRole, - pub content: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, +#[serde(tag = "role", rename_all = "lowercase")] +pub enum AgentMessage { + Assistant { + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + prefix: bool, + }, + User { + content: String, + }, + Tool { + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, } impl AgentMessage { pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { - Self { - role: AgentMessageRole::Assistant, + Self::Assistant { content: content.to_string(), tool_calls, + prefix: false, } } pub fn new_user_message(content: &str) -> Self { - Self { - role: AgentMessageRole::User, + Self::User { + content: content.to_string(), + } + } + pub fn new_prefix(content: &str) -> Self { + Self::Assistant { content: content.to_string(), tool_calls: None, + prefix: true, } } } -/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information. -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum AgentMessageRole { - #[serde(rename = "assistant")] - Assistant, - #[serde(rename = "user")] - User, - #[serde(rename = "tool")] - Tool, -} - /// The format that the model must output. /// /// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.