Skip to content

Commit

Permalink
feat(Agent)!: Add prefix message feature
Browse files Browse the repository at this point in the history
Now you can set a prefix to control the assistant output start.
It is the oportunity to improve the `AgentMessage` type. It's not a struct anymore.
It became an enum having a variant per possible roles.
  • Loading branch information
francois-caddet committed Aug 25, 2024
1 parent cd7416d commit 424b005
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 44 deletions.
11 changes: 5 additions & 6 deletions examples/agent.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use mistralai_client::v1::{
agent::{AgentMessage, AgentMessageRole, AgentParams},
agent::{AgentMessage, AgentParams},
client::Client,
};

Expand All @@ -9,11 +9,10 @@ fn main() {

let agid = std::env::var("MISTRAL_API_AGENT")
.expect("Please export MISTRAL_API_AGENT with the agent id you want to use");
let messages = vec![AgentMessage {
role: AgentMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
tool_calls: None,
}];
let messages = vec![
AgentMessage::new_user_message("What's the best city in the world?"),
AgentMessage::new_prefix("Valpo "),
];
let options = AgentParams {
random_seed: Some(42),
..Default::default()
Expand Down
11 changes: 5 additions & 6 deletions examples/agent_async.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use mistralai_client::v1::{
agent::{AgentMessage, AgentMessageRole, AgentParams},
agent::{AgentMessage, AgentParams},
client::Client,
};

Expand All @@ -10,11 +10,10 @@ async fn main() {

let agid = std::env::var("MISTRAL_API_AGENT")
.expect("Please export MISTRAL_API_AGENT with the agent id you want to use");
let messages = vec![AgentMessage {
role: AgentMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
tool_calls: None,
}];
let messages = vec![
AgentMessage::new_user_message("What's the best city in the world?"),
AgentMessage::new_prefix("Valpo "),
];
let options = AgentParams {
random_seed: Some(42),
..Default::default()
Expand Down
10 changes: 4 additions & 6 deletions examples/agent_with_function_calling.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use mistralai_client::v1::{
agent::{AgentMessage, AgentMessageRole, AgentParams},
agent::{AgentMessage, AgentParams},
client::Client,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
};
Expand Down Expand Up @@ -47,11 +47,9 @@ fn main() {

let agid = std::env::var("MISTRAL_API_AGENT")
.expect("Please export MISTRAL_API_AGENT with the agent id you want to use");
let messages = vec![AgentMessage {
role: AgentMessageRole::User,
content: "What's the temperature in Paris?".to_string(),
tool_calls: None,
}];
let messages = vec![AgentMessage::new_user_message(
"What's the temperature in Paris?",
)];
let options = AgentParams {
random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto),
Expand Down
10 changes: 4 additions & 6 deletions examples/agent_with_function_calling_async.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use mistralai_client::v1::{
agent::{AgentMessage, AgentMessageRole, AgentParams},
agent::{AgentMessage, AgentParams},
client::Client,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
};
Expand Down Expand Up @@ -48,11 +48,9 @@ async fn main() {

let agid = std::env::var("MISTRAL_API_AGENT")
.expect("Please export MISTRAL_API_AGENT with the agent id you want to use");
let messages = vec![AgentMessage {
role: AgentMessageRole::User,
content: "What's the temperature in Paris?".to_string(),
tool_calls: None,
}];
let messages = vec![AgentMessage::new_user_message(
"What's the temperature in Paris?",
)];
let options = AgentParams {
random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto),
Expand Down
47 changes: 27 additions & 20 deletions src/v1/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,48 @@ use crate::v1::{
// Definitions

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentMessage {
pub role: AgentMessageRole,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
#[serde(tag = "role", rename_all = "lowercase")]
pub enum AgentMessage {
Assistant {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<tool::ToolCall>>,
prefix: bool,
},
User {
content: String,
},
Tool {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
}
impl AgentMessage {
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self {
role: AgentMessageRole::Assistant,
Self::Assistant {
content: content.to_string(),
tool_calls,
prefix: false,
}
}

pub fn new_user_message(content: &str) -> Self {
Self {
role: AgentMessageRole::User,
Self::User {
content: content.to_string(),
}
}
pub fn new_prefix(content: &str) -> Self {
Self::Assistant {
content: content.to_string(),
tool_calls: None,
prefix: true,
}
}
}

/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum AgentMessageRole {
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "user")]
User,
#[serde(rename = "tool")]
Tool,
}

/// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
Expand Down

0 comments on commit 424b005

Please sign in to comment.