From d8d732c2424f680f237d392ab012220b34dc215d Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Wed, 11 Sep 2024 22:09:24 -0700 Subject: [PATCH 1/7] v0.2.0 Beta - Better Markdown support - Tested tool use with example - Builder pattern for `Requests` - Useful conversion shortcuts - Stream delta enhancements - Coverage, CI Still to add before release is - more coverage - streaming tool example, how to apply deltas - integrate examples in docs --- .github/workflows/tests.yaml | 74 ++++ .gitignore | 3 +- Cargo.toml | 15 +- README.md | 3 + examples/neologism.rs | 2 +- examples/strawberry.rs | 172 ++++++++ src/lib.rs | 8 + src/markdown.rs | 239 +++++++++++ src/model.rs | 16 +- src/request.rs | 744 +++++++++++++++++++++++++++++++- src/request/message.rs | 809 ++++++++++++++++++++++++++++++++--- src/response.rs | 2 +- src/response/message.rs | 34 +- src/stream.rs | 172 +++++--- src/tool.rs | 103 +++++ 15 files changed, 2258 insertions(+), 138 deletions(-) create mode 100644 .github/workflows/tests.yaml create mode 100644 examples/strawberry.rs create mode 100644 src/markdown.rs diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..7504096 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,74 @@ +# Credit to GitHub Copilot for generating this file +name: Rust CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + rust: [stable, beta, nightly] + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.rust }} + profile: minimal + override: true + + - name: Cache cargo registry + uses: actions/cache@v2 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v2 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v2 + with: + path: target + key: ${{ runner.os }}-cargo-build-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-build- + + - name: Build + run: cargo build --all-features --verbose + + - name: Run tests + run: cargo test --all-features --verbose + + - name: Install tarpaulin + if: matrix.os == 'ubuntu-latest' + run: cargo install cargo-tarpaulin + + - name: Run tarpaulin + if: matrix.os == 'ubuntu-latest' + run: cargo tarpaulin --out Xml --all-features + + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v2 + with: + files: ./cobertura.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index eb1d4ef..0d677be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target Cargo.lock -.vscode \ No newline at end of file +.vscode +cobertura.xml \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index eee2bd3..08fcc51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "misanthropic" -version = "0.1.4" +version = "0.2.0" edition = "2021" authors = ["Michael de Gans "] description = "An async, ergonomic, client for Anthropic's Messages API" @@ -33,6 +33,10 @@ reqwest = { version = "0.12", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "1" +# markdown support +pulldown-cmark = { version = "0.12", optional = true } +pulldown-cmark-to-cmark = { version = "17", optional = true } +static_assertions = "1" [dev-dependencies] clap = { version = "4", features = ["derive"] } @@ -63,3 +67,12 @@ prompt-caching = ["beta"] log = ["dep:log"] # Use rustls instead of the system SSL, such as OpenSSL. rustls-tls = ["reqwest/rustls-tls"] +# Use `pulldown-cmark` for markdown parsing and `pulldown-cmark-to-cmark` for +# converting to CommonMark. +markdown = ["dep:pulldown-cmark", "dep:pulldown-cmark-to-cmark"] +# Derive PartialEq for all structs and enums. +partialeq = [] + +[[example]] +name = "strawberry" +required-features = ["markdown"] diff --git a/README.md b/README.md index 8d275f1..1717af4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # `misanthropic` +![Build Status](https://github.com/mdegans/misanthropic/actions/workflows/tests.yaml/badge.svg) +[![codecov](https://codecov.io/gh/mdegans/misanthropic/branch/main/graph/badge.svg)](https://codecov.io/gh/your-username/your-repo) + Is an unofficial simple, ergonomic, client for the Anthropic Messages API. ## Usage diff --git a/examples/neologism.rs b/examples/neologism.rs index 800803c..0c7d41c 100644 --- a/examples/neologism.rs +++ b/examples/neologism.rs @@ -54,7 +54,7 @@ async fn main() -> Result<(), Box> { content: args.prompt.into(), }], max_tokens: 1000.try_into().unwrap(), - metadata: serde_json::Value::Null, + metadata: serde_json::Map::new(), stop_sequences: None, stream: None, system: None, diff --git a/examples/strawberry.rs b/examples/strawberry.rs new file mode 100644 index 0000000..99c6f2c --- /dev/null +++ b/examples/strawberry.rs @@ -0,0 +1,172 @@ +//! An example of tool use and tool results. Language models are sometimes +//! unreasonably mocked since they cannot count letters within tokens (because +//! they do not see words as humans do). This example demonstrates how easy it +//! is to overcome this with an assistive device in the form of a tool. + +use std::io::BufRead; + +use clap::Parser; +use misanthropic::{ + json, + markdown::{self, ToMarkdown}, + request::{ + message::{Block, Role}, + Message, + }, + response, tool, Client, Request, Tool, +}; + +/// Count the number of letters in a word (or any string). An example of tool +/// use and tool results. +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// User prompt. + #[arg( + short, + long, + default_value = "Count the number of r's in 'strawberry'" + )] + prompt: String, + /// Show tool use. + #[arg(long)] + verbose: bool, +} + +/// Things that can go wrong. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Assistant did not call the tool. + #[error("Assistant did not call the tool. Response: {message}")] + NoToolCall { message: response::Message }, + /// Assistant called the wrong tool. + #[error("Assistant called the tool incorrectly. Call: {call}")] + MalformedToolCall { call: tool::Use }, +} + +/// Count the number of letters in a word (or any string). +pub fn count_letters(letter: char, string: String) -> usize { + let letter = letter.to_ascii_lowercase(); + let string = string.to_ascii_lowercase(); + + string.chars().filter(|c| *c == letter).count() +} + +/// Handle the tool call. Returns a [`User`] message with the result. +/// +/// [`User`]: Role::User +pub fn handle_tool_call(call: &tool::Use) -> Result { + if call.name != "count_letters" { + return Err(Error::MalformedToolCall { call: call.clone() }); + } + + if let (Some(letter), Some(string)) = ( + call.input["letter"].as_str().and_then(|s| s.chars().next()), + call.input["string"].as_str(), + ) { + let count = count_letters(letter, string.into()); + + Ok(Message { + role: Role::User, + content: Block::ToolResult { + tool_use_id: call.id.clone(), + content: count.to_string().into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + // A Content Block is always convertable into Content. + .into(), + }) + } else { + // Optionally, we could always return a Message and inform the Assistant + // that they called the tool incorrectly so they can try again. + Err(Error::MalformedToolCall { call: call.clone() }) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Read the command line arguments. + let args = Args::parse(); + + // Get API key from stdin. + println!("Enter your API key:"); + let key = std::io::stdin().lock().lines().next().unwrap()?; + + // Create a client. `key` will be consumed and zeroized. + let client = Client::new(key)?; + + // Craft our chat request, providing a Tool definition to call + // `count_letters`. In the future this will be derivable from the function + // signature and docstring. Like many things in our API, `Tool` is also + // convertable from a `serde_json::Value`. + let mut chat = Request::default().add_tool(Tool { + name: "count_letters".into(), + description: "Count the number of letters in a word.".into(), + input_schema: json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + }), + #[cfg(feature = "prompt-caching")] + cache_control: None, + // Inform the assistant about their limitations. + }).system("You are a helpful assistant. You cannot count letters in a word by yourself because you see in tokens, not letters. Use the `count_letters` tool to overcome this limitation.") + // Add user input. + .add_message(Message { + role: Role::User, + content: args.prompt.into(), + }); + + // Generate the next message in the chat. + let message = client.message(&chat).await?; + + // Check if the Assistant called the Tool. The `stop_reason` must be + // `ToolUse` and the last `Content` `Block` must be `ToolUse`. + if let Some(call) = message.tool_use() { + let result = handle_tool_call(call)?; + // Append the tool request and result messages to the chat. + chat.messages.push(message.into()); + chat.messages.push(result); + } else { + // The Assistant did not call the tool. This may not be an error if the + // user did not ask for the tool to be used, in which case it could be + // handled as a normal message. + return Err(Error::NoToolCall { message }.into()); + } + + let message = client.message(&chat).await?; + + if args.verbose { + // Append the message and print the entire conversation as Markdown. The + // default display also renders markdown, but without system prompt and + // tool use information. + chat.messages.push(message.into()); + println!( + "{}", + chat.markdown_custom( + &markdown::Options::default() + .with_system() + .with_tool_use() + .with_tool_results() + ) + ); + } else { + // Just print the message content. The response `Message` contains the + // `request::Message` with a `Role` and `Content`. The message can also + // be printed directly, but this will include the `Role` header. + println!("{}", message.message.content); + } + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index a9e60f3..5f0b320 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,10 @@ pub use tool::Tool; pub mod response; pub use response::Response; +#[cfg(feature = "markdown")] +/// Markdown utilities for parsing and rendering. +pub mod markdown; + /// Re-exports of commonly used crates to avoid version conflicts and reduce /// dependency bloat. pub mod exports { @@ -43,6 +47,10 @@ pub mod exports { #[cfg(feature = "log")] pub use log; pub use memsecurity; + #[cfg(feature = "markdown")] + pub use pulldown_cmark; + #[cfg(feature = "markdown")] + pub use pulldown_cmark_to_cmark; pub use reqwest; pub use serde; pub use serde_json; diff --git a/src/markdown.rs b/src/markdown.rs new file mode 100644 index 0000000..70254fc --- /dev/null +++ b/src/markdown.rs @@ -0,0 +1,239 @@ +use std::ops::Deref; + +use serde::{Deserialize, Serialize}; + +/// Default [`Options`] +pub const DEFAULT_OPTIONS: Options = Options { + inner: pulldown_cmark::Options::empty(), + tool_use: false, + tool_results: false, + system: false, +}; + +/// A static reference to the default [`Options`]. +pub static DEFAULT_OPTIONS_REF: &'static Options = &DEFAULT_OPTIONS; + +mod serde_inner { + use super::*; + + pub fn serialize( + options: &pulldown_cmark::Options, + serializer: S, + ) -> Result + where + S: serde::Serializer, + { + options.bits().serialize(serializer) + } + + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'de>, + { + let bits = u32::deserialize(deserializer)?; + Ok(pulldown_cmark::Options::from_bits_truncate(bits)) + } +} + +/// Options for parsing, generating, and rendering [`Markdown`]. +#[derive(Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +pub struct Options { + /// Inner [`pulldown_cmark::Options`]. + #[serde(with = "serde_inner")] + pub inner: pulldown_cmark::Options, + /// Whether to include the system prompt + #[serde(default)] + pub system: bool, + /// Whether to include tool uses. + #[serde(default)] + pub tool_use: bool, + /// Whether to include tool results. + #[serde(default)] + pub tool_results: bool, +} + +impl Options { + /// Set [`tool_use`] to true + /// + /// [`tool_use`]: Options::tool_use + pub fn with_tool_use(mut self) -> Self { + self.tool_use = true; + self + } + + /// Set [`tool_results`] to true + /// + /// [`tool_results`]: Options::tool_results + pub fn with_tool_results(mut self) -> Self { + self.tool_results = true; + self + } + + /// Set [`system`] to true + /// + /// [`system`]: Options::system + pub fn with_system(mut self) -> Self { + self.system = true; + self + } +} + +#[cfg(feature = "markdown")] +impl From for Options { + fn from(inner: pulldown_cmark::Options) -> Self { + Options { + inner, + ..Default::default() + } + } +} + +/// A valid, immutable, Markdown string. It has been parsed and rendered. It can +/// be [`Display`]ed or dereferenced as a [`str`]. +/// +/// [`Display`]: std::fmt::Display +#[derive(derive_more::Display)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[display("{text}")] +pub struct Markdown { + text: String, +} + +impl Into for Markdown { + fn into(self) -> String { + self.text + } +} + +impl AsRef for Markdown { + fn as_ref(&self) -> &str { + self.deref().as_ref() + } +} + +impl std::borrow::Borrow for Markdown { + fn borrow(&self) -> &str { + self.as_ref() + } +} + +impl std::ops::Deref for Markdown { + type Target = str; + + fn deref(&self) -> &str { + &self.text + } +} + +impl<'a, T> From for Markdown +where + T: Iterator>, +{ + fn from(events: T) -> Self { + let mut text = String::new(); + + // Unwrap can never panic because the formatter for `String` never + // returns an error. + let _ = pulldown_cmark_to_cmark::cmark(events, &mut text).unwrap(); + + Markdown { text } + } +} + +#[cfg(any(test, feature = "partial_eq"))] +impl PartialEq for Markdown { + fn eq(&self, other: &str) -> bool { + self.text == other + } +} + +/// A trait for types that can be converted to [`Markdown`]. +pub trait ToMarkdown { + /// Render the type to a [`Markdown`] string with [`DEFAULT_OPTIONS`]. + fn markdown(&self) -> Markdown { + self.markdown_custom(DEFAULT_OPTIONS_REF) + } + + /// Render the type to a [`Markdown`] string with custom [`Options`]. + fn markdown_custom(&self, options: &Options) -> Markdown { + self.markdown_events_custom(options).into() + } + + /// Render the markdown to a type implementing [`std::fmt::Write`] with + /// [`DEFAULT_OPTIONS`]. + fn write_markdown( + &self, + writer: &mut dyn std::fmt::Write, + ) -> std::fmt::Result { + self.write_markdown_custom(writer, DEFAULT_OPTIONS_REF) + } + + /// Render the markdown to a type implementing [`std::fmt::Write`] with + /// custom [`Options`]. + fn write_markdown_custom( + &self, + writer: &mut dyn std::fmt::Write, + options: &Options, + ) -> std::fmt::Result { + use pulldown_cmark_to_cmark::cmark; + + let events = self.markdown_events_custom(options); + let _ = cmark(events, writer)?; + Ok(()) + } + + /// Return an iterator of [`pulldown_cmark::Event`]s with + /// [`DEFAULT_OPTIONS`]. + fn markdown_events<'a>( + &'a self, + ) -> Box> + 'a> { + self.markdown_events_custom(DEFAULT_OPTIONS_REF) + } + + /// Return an iterator of [`pulldown_cmark::Event`]s with custom + /// [`Options`]. + fn markdown_events_custom<'a>( + &'a self, + options: &'a Options, + ) -> Box> + 'a>; +} + +static_assertions::assert_obj_safe!(ToMarkdown); + +impl Default for Options { + fn default() -> Self { + DEFAULT_OPTIONS + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::borrow::Borrow; + + #[test] + fn test_options_serde() { + let options = Options::default(); + + let json = serde_json::to_string(&options).unwrap(); + let options2: Options = serde_json::from_str(&json).unwrap(); + + assert!(options == options2); + } + + #[test] + fn test_markdown() { + let expected = "Hello, **world**!"; + let events = pulldown_cmark::Parser::new(&expected); + let markdown: Markdown = events.into(); + let actual: &str = markdown.borrow(); + assert_eq!(actual, expected); + assert!(&markdown == expected); + let markdown: String = markdown.into(); + assert_eq!(markdown, expected); + } +} diff --git a/src/model.rs b/src/model.rs index 11f79f5..94ce7ab 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,10 +1,18 @@ //! [`Model`] to use for inference. - use serde::{Deserialize, Serialize}; /// Model to use for inference. Note that **some features may limit choices**. #[derive( - Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, + Debug, + Default, + Clone, + Copy, + Serialize, + Deserialize, + PartialEq, + Eq, + PartialOrd, + Ord, )] #[serde(rename_all = "snake_case")] pub enum Model { @@ -12,14 +20,14 @@ pub enum Model { #[serde(rename = "claude-3-5-sonnet-20240620")] Sonnet35, /// Opus 3.0. - #[cfg(not(feature = "prompt-caching"))] #[serde(rename = "claude-3-opus-20240229")] Opus30, /// Sonnet 3.0 #[cfg(not(feature = "prompt-caching"))] #[serde(rename = "claude-3-sonnet-20240229")] Sonnet30, - /// Haiku 3.0. + /// Haiku 3.0. This is the default model. + #[default] #[serde(rename = "claude-3-haiku-20240307")] Haiku30, } diff --git a/src/request.rs b/src/request.rs index 065e976..6d52ea9 100644 --- a/src/request.rs +++ b/src/request.rs @@ -2,9 +2,10 @@ //! //! [Anthropic Messages API]: -use std::num::NonZeroU16; +use std::{borrow::Cow, num::NonZeroU16, vec}; use crate::{tool, Model, Tool}; +use message::Content; use serde::{Deserialize, Serialize}; pub mod message; @@ -14,6 +15,7 @@ pub use message::Message; /// /// [Anthropic Messages API]: #[derive(Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub struct Request { /// [`Model`] to use for inference. pub model: Model, @@ -34,17 +36,14 @@ pub struct Request { pub max_tokens: NonZeroU16, /// Optional info about the request, for example, `user_id` to help /// Anthropic detect and prevent abuse. Do not use PII here (email, phone). - /// Use the [`json!`] macro to create this easily. - /// - /// [`json!`]: serde_json::json - #[serde(skip_serializing_if = "serde_json::Value::is_null")] - pub metadata: serde_json::Value, + #[serde(skip_serializing_if = "serde_json::Map::is_empty")] + pub metadata: serde_json::Map, /// Optional stop sequences. If the model generates any of these sequences, /// the completion will stop with [`StopReason::StopSequence`]. /// /// [`StopReason::StopSequence`]: crate::response::StopReason::StopSequence #[serde(skip_serializing_if = "Option::is_none")] - pub stop_sequences: Option>, + pub stop_sequences: Option>>, /// If `true`, the response will be a stream of [`Event`]s. If `false`, the /// response will be a single [`response::Message`]. /// @@ -81,3 +80,734 @@ pub struct Request { #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, } + +impl Default for Request { + fn default() -> Self { + Self { + model: Default::default(), + messages: Default::default(), + max_tokens: NonZeroU16::new(4096).unwrap(), + metadata: Default::default(), + stop_sequences: Default::default(), + stream: Default::default(), + system: Default::default(), + temperature: Default::default(), + tool_choice: Default::default(), + tools: Default::default(), + top_k: Default::default(), + top_p: Default::default(), + } + } +} + +impl Request { + /// Turn streaming on. + /// + /// **Note**: [`Client::stream`] and [`Client::message`] are more ergonomic + /// and will overwrite this setting. + /// + /// [`Client::stream`]: crate::Client::stream + /// [`Client::message`]: crate::Client::message + pub fn stream(mut self) -> Self { + self.stream = Some(true); + self + } + + /// Turn streaming off. + /// + /// **Note**: [`Client::stream`] and [`Client::message`] are more ergonomic + /// and will overwrite this setting. + /// + /// [`Client::stream`]: crate::Client::stream + /// [`Client::message`]: crate::Client::message + pub fn no_stream(mut self) -> Self { + self.stream = Some(false); + self + } + + /// Set the [`model`] to a [`Model`]. + /// + /// [`model`]: Request::model + pub fn model(mut self, model: Model) -> Self { + self.model = model; + self + } + + /// Set the [`messages`] from an iterable. + /// + /// [`messages`]: Request::messages + pub fn messages(mut self, messages: Ms) -> Self + where + M: Into, + Ms: IntoIterator, + { + self.messages = messages.into_iter().map(Into::into).collect(); + self + } + + /// Add a [`Message`] to [`messages`]. + /// + /// [`messages`]: Request::messages + pub fn add_message(mut self, message: M) -> Self + where + M: Into, + { + self.messages.push(message.into()); + self + } + + /// Extend the [`messages`] from an iterable. + /// + /// [`messages`]: Request::messages + pub fn extend_messages(mut self, messages: Ms) -> Self + where + M: Into, + Ms: IntoIterator, + { + self.messages.extend(messages.into_iter().map(Into::into)); + self + } + + /// Set the [`max_tokens`]. If this is reached, the [`StopReason`] will be + /// [`MaxTokens`] in the [`response::Message::stop_reason`]. + /// + /// [`max_tokens`]: Request::max_tokens + /// [`StopReason`]: crate::response::StopReason + /// [`MaxTokens`]: crate::response::StopReason::MaxTokens + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + pub fn max_tokens(mut self, max_tokens: NonZeroU16) -> Self { + self.max_tokens = max_tokens; + self + } + + /// Set the [`metadata`] from an iterable of key-value pairs. + /// The values must be serializable to JSON. + /// + /// # Panics + /// - if a value cannot be serialized to JSON. + /// + /// See [`try_metadata`] for a fallible version. + /// + /// [`metadata`]: Request::metadata + pub fn metadata(mut self, metadata: Vs) -> Self + where + S: Into, + V: Serialize, + Vs: IntoIterator, + { + self.metadata = metadata + .into_iter() + .map(|(k, v)| (k.into(), serde_json::to_value(v).unwrap())) + .collect(); + self + } + + /// Set the [`metadata`] from an iterable of key-value pairs. + /// The values must be serializable to JSON. + /// + /// [`metadata`]: Request::metadata + pub fn try_metadata( + mut self, + metadata: Vs, + ) -> Result + where + S: Into, + V: Serialize, + Vs: IntoIterator, + { + let mut map = serde_json::Map::new(); + + for (k, v) in metadata { + map.insert(k.into(), serde_json::to_value(v)?); + } + + self.metadata = map; + + Ok(self) + } + + /// Insert a key-value pair into the metadata. Replace the value if the key + /// already exists. + pub fn insert_metadata( + mut self, + key: S, + value: V, + ) -> Result + where + S: Into, + V: Serialize, + { + self.metadata + .insert(key.into(), serde_json::to_value(value)?); + Ok(self) + } + + /// Set the [`stop_sequences`]. If one is generated, the completion will + /// stop with [`StopReason::StopSequence`] in the + /// [`response::Message::stop_reason`]. + /// + /// [`stop_sequences`]: Request::stop_sequences + /// [`StopReason::StopSequence`]: crate::response::StopReason::StopSequence + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + pub fn stop_sequences(mut self, stop_sequences: Ss) -> Self + where + S: Into>, + Ss: IntoIterator, + { + self.stop_sequences = + Some(stop_sequences.into_iter().map(Into::into).collect()); + self + } + + /// Add a stop sequence to [`stop_sequences`]. If one is generated, the + /// completion will stop with [`StopReason::StopSequence`] in the + /// [`response::Message::stop_reason`]. + /// + /// [`stop_sequences`]: Request::stop_sequences + /// [`StopReason::StopSequence`]: crate::response::StopReason::StopSequence + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + pub fn stop_sequence(mut self, stop_sequence: S) -> Self + where + S: Into>, + { + self.stop_sequences + .get_or_insert_with(Default::default) + .push(stop_sequence.into()); + self + } + + /// Extend the [`stop_sequences`] from an iterable. If one is generated, the + /// completion will stop with [`StopReason::StopSequence`] in the + /// [`response::Message::stop_reason`]. + /// + /// [`stop_sequences`]: Request::stop_sequences + /// [`StopReason::StopSequence`]: crate::response::StopReason::StopSequence + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + pub fn extend_stop_sequences(mut self, stop_sequences: Ss) -> Self + where + S: Into>, + Ss: IntoIterator, + { + self.stop_sequences + .get_or_insert_with(Default::default) + .extend(stop_sequences.into_iter().map(Into::into)); + self + } + + /// Set the [`system`] prompt [`Content`]. This is content that the model + /// will give special attention to. Instructions should be placed here. + /// + /// [`system`]: Request::system + pub fn system(mut self, system: S) -> Self + where + S: Into, + { + self.system = Some(system.into()); + self + } + + /// Add a [`Block`] to the [`system`] prompt [`Content`]. If there is no + /// [`system`] prompt, one will be created with the supplied `block`. + /// + /// Among the types that can convert to a [`Block`] are: + /// * [`str`] slices + /// * [`String`] + /// * [`message::Image`] base64-encoded images + /// + /// With the `image` feature flag: + /// * [`image::RgbaImage`] images (they will be encoded as PNG) + /// * [`image::DynamicImage`] images (they will be converted to RGBA and + /// encoded as PNG) + /// + /// For other image formats, see the [`message::Image::encode`] method, + /// the [`MediaType`] enum, and the image codec feature flags. + /// + /// [`system`]: Request::system + /// [`Block`]: message::Block + /// [`MediaType`]: message::MediaType + pub fn add_system_block(mut self, block: B) -> Self + where + B: Into, + { + match self.system { + Some(mut content) => { + content.push(block); + self.system = Some(content); + } + None => { + // MultiPart doesn't actually need to have multiple parts. + self.system = Some(Content::MultiPart(vec![block.into()])); + } + } + self + } + + /// Set the [`temperature`] to `Some(value)` or [`None`] to use the default. + /// + /// [`temperature`]: Request::temperature + pub fn temperature(mut self, temperature: Option) -> Self { + self.temperature = temperature; + self + } + + /// Set the [`tool::Choice`]. This constrains how the model uses tools. + /// + /// [`tool::Choice`]: crate::tool::Choice + pub fn tool_choice(mut self, choice: tool::Choice) -> Self { + self.tool_choice = Some(choice); + self + } + + /// Set the available [`tools`]. When the [`Model`] uses a [`Tool`], the + /// [`StopReason`] will be [`ToolUse`] in the + /// [`response::Message::stop_reason`] and the final [`Content`] [`Block`] + /// will be [`Block::ToolUse`] with a unique [`id`]. + /// + /// The response may then be provided in a [`Message`] with a [`Role`] of + /// [`User`] and [`Content`] [`Block`] of [`ToolResult`] with matching + /// [`tool_use_id`] to the [`ToolUse::id`]. + /// + /// For a fallible version, see [`try_tools`]. + /// + /// [`tools`]: Request::tools + /// [`Tool`]: crate::Tool + /// [`StopReason`]: crate::response::StopReason + /// [`ToolUse`]: crate::response::StopReason::ToolUse + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + /// [`Block::ToolUse`]: crate::request::message::Block::ToolUse + /// [`id`]: crate::request::message::Block::ToolUse::id + /// [`Role`]: crate::request::message::Role + /// [`User`]: crate::request::message::Role::User + /// [`Block`]: crate::request::message::Block + /// [`ToolResult`]: crate::request::message::Block::ToolResult + /// [`tool_use_id`]: crate::request::message::Block::ToolResult::tool_use_id + /// [`ToolUse::id`]: crate::request::message::Block::ToolUse::id + /// [`try_tools`]: Request::try_tools + pub fn tools(mut self, tools: Ts) -> Self + where + T: Into, + Ts: IntoIterator, + { + self.tools = Some(tools.into_iter().map(Into::into).collect()); + self + } + + /// Try to set the [`tools`]. When the [`Model`] uses a [`Tool`], the + /// [`StopReason`] will be [`ToolUse`] in the + /// [`response::Message::stop_reason`] and the final [`Content`] [`Block`] + /// will be [`Block::ToolUse`] with a unique [`id`]. + /// + /// The response may then be provided in a [`Message`] with a [`Role`] of + /// [`User`] and [`Content`] [`Block`] of [`ToolResult`] with matching + /// [`tool_use_id`] to the [`ToolUse::id`]. + /// + /// [`tools`]: Request::tools + /// [`Tool`]: crate::Tool + /// [`StopReason`]: crate::response::StopReason + /// [`ToolUse`]: crate::response::StopReason::ToolUse + /// [`response::Message::stop_reason`]: crate::response::Message::stop_reason + /// [`Block::ToolUse`]: crate::request::message::Block::ToolUse + /// [`id`]: crate::request::message::Block::ToolUse::id + /// [`Role`]: crate::request::message::Role + /// [`User`]: crate::request::message::Role::User + /// [`Block`]: crate::request::message::Block + /// [`ToolResult`]: crate::request::message::Block::ToolResult + /// [`tool_use_id`]: crate::request::message::Block::ToolResult::tool_use_id + /// [`ToolUse::id`]: crate::request::message::Block::ToolUse::id + pub fn try_tools(mut self, tools: Ts) -> Result + where + T: TryInto, + Ts: IntoIterator, + { + self.tools = Some( + tools + .into_iter() + .map(TryInto::try_into) + .collect::>()?, + ); + Ok(self) + } + + /// Add a tool to the request. + pub fn add_tool(mut self, tool: T) -> Self + where + T: Into, + { + self.tools + .get_or_insert_with(Default::default) + .push(tool.into()); + self + } + + /// Try to add a tool to the request. Returns an error if the value cannot + /// be converted into a [`Tool`]. + pub fn try_add_tool(mut self, tool: T) -> Result + where + T: TryInto, + { + self.tools + .get_or_insert_with(Default::default) + .push(tool.try_into()?); + Ok(self) + } + + // No extend for tools because it's not very common or useful. If somebody + // really wants this they can submit a PR. + + /// Set the top K tokens to consider for each token. Set to `None` to use + /// the default value. + pub fn top_k(mut self, top_k: Option) -> Self { + self.top_k = top_k; + self + } + + /// Set the top P for nucleus sampling. Set to [`None`] to use the default + /// value. + pub fn top_p(mut self, top_p: Option) -> Self { + self.top_p = top_p; + self + } + + /// Add a cache breakpoint to the end of the prompt, setting `cache_control` + /// to `Ephemeral`. + /// + /// # Notes + /// * Cache breakpoints apply to the full prefix in the order of [`tools`], + /// [`system`], and [`messages`]. To effectively use this method, call it + /// after setting [`tools`] and [`system`] if you have no examples or + /// after setting [`messages`] if you do. + /// * For [`Sonnet35`] and [`Opus30`] models, the prompt must have at least + /// 1024 tokens for this to have an effect. For [`Haiku30`], the minimum + /// is 2048 tokens. + /// * Since this is a beta feature, the API may change in the future, likely + /// to include another form of `cache_control`. + #[cfg(feature = "prompt-caching")] + pub fn cache(mut self) -> Self { + // If there are messages, add a cache breakpoint to the last one. + if let Some(last) = self.messages.last_mut() { + last.content.cache(); + return self; + } + + // If there are no messages, add a cache breakpoint to the system prompt + // if it exists. + if let Some(system) = self.system.as_mut() { + system.cache(); + return self; + } + + // If there are no messages or system prompt, add a cache breakpoint to + // the tools if they exist. + if let Some(tool) = + self.tools.as_mut().map(|tools| tools.last_mut()).flatten() + { + tool.cache(); + return self; + } + + self + } +} + +#[cfg(feature = "markdown")] +impl crate::markdown::ToMarkdown for Request { + fn markdown_events_custom<'a>( + &'a self, + options: &'a crate::markdown::Options, + ) -> Box> + 'a> { + let system = if let Some(system) = self + .system + .as_ref() + .map(|s| s.markdown_events_custom(options)) + { + if options.system { + system + } else { + Box::new(std::iter::empty()) + } + } else { + Box::new(std::iter::empty()) + }; + + let messages = self + .messages + .iter() + .flat_map(move |m| m.markdown_events_custom(options)); + + Box::new(system.chain(messages)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::num::NonZeroU16; + + use crate::request::message::Role; + + const MESSAGE: Message = Message { + role: Role::User, + content: Content::text("Hello"), + }; + + const MESSAGE2: Message = Message { + role: Role::Assistant, + content: Content::text("Hi"), + }; + + const MESSAGES: [Message; 2] = [MESSAGE, MESSAGE2]; + + const STOP_SEQUENCES: [&'static str; 2] = ["stop1", "stop2"]; + + // Credit to GitHub Copilot for the following tests. + + #[test] + fn test_default_request() { + let request = Request::default(); + assert_eq!(request.model, Model::default()); + assert!(request.messages.is_empty()); + assert_eq!(request.max_tokens, NonZeroU16::new(4096).unwrap()); + assert!(request.metadata.is_empty()); + assert!(request.stop_sequences.is_none()); + assert!(request.stream.is_none()); + assert!(request.system.is_none()); + assert!(request.temperature.is_none()); + assert!(request.tool_choice.is_none()); + assert!(request.tools.is_none()); + assert!(request.top_k.is_none()); + assert!(request.top_p.is_none()); + } + + #[test] + fn test_stream_on() { + let request = Request::default().stream(); + assert_eq!(request.stream, Some(true)); + } + + #[test] + fn test_stream_off() { + let request = Request::default().no_stream(); + assert_eq!(request.stream, Some(false)); + } + + #[test] + fn test_set_model() { + let model = Model::default(); + let request = Request::default().model(model); // Model is Copy + assert_eq!(request.model, model); + } + + #[test] + fn test_set_messages() { + let request = Request::default().messages(MESSAGES); + assert_eq!(request.messages, MESSAGES); + } + + #[test] + fn test_add_message() { + let mut request = Request::default(); + request = request.add_message(MESSAGE).add_message(MESSAGE2); + assert_eq!(request.messages.len(), 2); + assert_eq!(request.messages[0], MESSAGE); + assert_eq!(request.messages[1], MESSAGE2); + } + + #[test] + fn test_extend_messages() { + let mut request = Request::default(); + request = request.extend_messages(MESSAGES); + assert_eq!(request.messages, MESSAGES); + } + + #[test] + fn test_set_max_tokens() { + let max_tokens = NonZeroU16::new(1024).unwrap(); + let request = Request::default().max_tokens(max_tokens); + assert_eq!(request.max_tokens, max_tokens); + } + + #[test] + fn test_set_metadata() { + let metadata = vec![("key".to_string(), json!("value"))]; + let request = Request::default().metadata(metadata); + assert_eq!(request.metadata.get("key").unwrap(), "value"); + } + + #[test] + fn test_try_metadata() { + let request = Request::default() + .try_metadata([("key", "value"), ("key2", "value2")]) + .unwrap(); + assert_eq!(request.metadata.get("key").unwrap(), "value"); + assert_eq!(request.metadata.get("key2").unwrap(), "value2"); + } + + #[test] + fn test_insert_metadata() { + let request = + Request::default().insert_metadata("key", "value").unwrap(); + assert_eq!(request.metadata.get("key").unwrap(), "value"); + } + + #[test] + fn test_set_stop_sequences() { + let request = Request::default().stop_sequences(STOP_SEQUENCES); + assert_eq!(request.stop_sequences.unwrap(), STOP_SEQUENCES); + } + + #[test] + fn test_add_stop_sequence() { + let mut request = Request::default(); + request = request.stop_sequence(STOP_SEQUENCES[0]); + assert_eq!(request.stop_sequences.as_ref().unwrap().len(), 1); + assert_eq!(request.stop_sequences.unwrap()[0], STOP_SEQUENCES[0]); + } + + #[test] + fn test_extend_stop_sequences() { + let mut request = Request::default(); + request = request.extend_stop_sequences(STOP_SEQUENCES); + assert_eq!(request.stop_sequences.unwrap().len(), 2); + } + + #[test] + fn test_set_system() { + let request = Request::default().system("system"); + assert_eq!(request.system.unwrap().to_string(), "system"); + } + + // End of GitHub Copilot tests. + + #[test] + fn test_add_system_block() { + // Test with a system prompt. The call to cache should affect the final + // Block in the system prompt. + let request = Request::default() + .add_system_block("Do this.") // Will add a system Content block + .add_system_block("And then do this."); + + assert_eq!( + request.system.as_ref().unwrap().to_string(), + "Do this.\n\nAnd then do this." + ); + } + + #[test] + #[cfg(feature = "prompt-caching")] + fn test_cache() { + // Test with no system prompt or messages that the call to cache affects + // the tools. + let request = Request::default().add_tool(Tool { + name: "ping".into(), + description: "Ping a server.".into(), + input_schema: json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }); + + assert!(!request.tools.as_ref().unwrap().last().unwrap().is_cached()); + + let mut request = request.cache(); + + assert!(request.tools.as_ref().unwrap().last().unwrap().is_cached()); + + // remove the cache breakpoint + // TODO: add an un_cache method? set_cache? + request + .tools + .as_mut() + .unwrap() + .last_mut() + .unwrap() + .cache_control = None; + + // Test with a system prompt. The call to cache should affect the final + // Block in the system prompt. + let request = request + .add_system_block("Do this.") // Will add a system Content block + .add_system_block("And then do this.") + .cache(); + + assert!(request.system.as_ref().unwrap().last().unwrap().is_cached()); + // ensure the tools are not affected + assert!(!request.tools.as_ref().unwrap().last().unwrap().is_cached()); + + // Test with messages. The call to cache should affect the last message. + let request = request + .add_message(Message { + role: Role::User, + content: Content::text("Hello"), + }) + .add_message(Message { + role: Role::Assistant, + content: Content::text("Hi"), + }) + .cache(); + + // The first message should still be a single part string. + assert!(request.messages.first().unwrap().content.last().is_none()); + + // By now the final part should be a multi part string, since only + // Block has `cache_control` + assert!(request + .messages + .last() + .unwrap() + .content + .last() + .unwrap() + .is_cached()); + } + + #[test] + fn test_tools() { + // A tool can be added from a json object. This is fallible. It must + // deserialize into a Tool. + let json_tool = json!({ + "name": "ping2", + "description": "Ping a server. Part deux.", + "input_schema": { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "The host to ping." + } + }, + "required": ["host"] + } + }); + + let schema = json_tool["input_schema"].clone(); + + // A tool can be created from a Tool itself. This is infallible, however + // the API might reject the request if the tool is invalid. There is + // currently no schema validation in this crate. + let tool = Tool { + name: "ping".into(), + description: "Ping a server.".into(), + input_schema: schema.clone(), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let request = Request::default() + .tools([tool]) + .try_add_tool(json_tool) + .unwrap(); + + assert_eq!(request.tools.as_ref().unwrap().len(), 2); + assert_eq!(request.tools.as_ref().unwrap()[0].name, "ping"); + assert_eq!(request.tools.as_ref().unwrap()[1].name, "ping2"); + assert_eq!( + request.tools.as_ref().unwrap()[0].description, + "Ping a server." + ); + assert_eq!( + request.tools.as_ref().unwrap()[1].description, + "Ping a server. Part deux." + ); + assert_eq!(request.tools.as_ref().unwrap()[0].input_schema, schema); + } +} diff --git a/src/request/message.rs b/src/request/message.rs index 66f3377..a9cd478 100644 --- a/src/request/message.rs +++ b/src/request/message.rs @@ -4,14 +4,21 @@ //! [`response::Message`]: crate::response::Message //! [`request::Message`]: crate::request::Message +use std::borrow::Cow; + use base64::engine::{general_purpose, Engine as _}; use serde::{Deserialize, Serialize}; -use crate::response; +use crate::{ + response, + stream::{ContentMismatch, Delta, DeltaError}, + tool, +}; /// Role of the [`Message`] author. -#[derive(Clone, Copy, Debug, Serialize, Deserialize, derive_more::Display)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub enum Role { /// From the user. User, @@ -19,24 +26,48 @@ pub enum Role { Assistant, } +impl Role { + /// Get the string representation of the role. + pub const fn as_str(&self) -> &'static str { + match self { + Self::User => "User", + Self::Assistant => "Assistant", + } + } +} + +impl std::fmt::Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + /// A message in a [`Request`]. See [`response::Message`] for the version with /// additional metadata. /// -/// A message is rendered as markdown with a [heading] indicating the [`Role`] -/// of the author. [`Image`]s are supported and will be rendered as markdown -/// images with embedded base64 data. [`Content`] [`Part`]s are separated by -/// [`Content::SEP`]. +/// A message is [`Display`]ed as markdown with a [heading] indicating the +/// [`Role`] of the author. [`Image`]s are supported and will be rendered as +/// markdown images with embedded base64 data. /// +/// [`Display`]: std::fmt::Display /// [`Request`]: crate::Request /// [`response::Message`]: crate::response::Message /// [heading]: Message::HEADING -#[derive(Debug, Serialize, Deserialize, derive_more::Display)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -#[display("{}{}{}{}", Self::HEADING, role, Content::SEP, content)] +#[cfg_attr( + not(feature = "markdown"), + derive(derive_more::Display), + display("{}{}{}{}", Self::HEADING, role, Content::SEP, content) +)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub struct Message { - /// Who is providing the content. + /// Who is the message from. pub role: Role, - /// The content of the message. + /// The [`Content`] of the message as [one] or [more] [`Block`]s. + /// + /// [one]: Content::SinglePart + /// [more]: Content::MultiPart pub content: Content, } @@ -44,7 +75,25 @@ impl Message { /// Heading for the message when rendered as markdown using [`Display`]. /// /// [`Display`]: std::fmt::Display + #[cfg(not(feature = "markdown"))] pub const HEADING: &'static str = "### "; + /// Heading for the message when rendered as markdown using markdown methods + /// as well as [`Display`]. + /// + /// [`Display`]: std::fmt::Display + #[cfg(feature = "markdown")] + pub const HEADING: pulldown_cmark::Tag<'static> = + pulldown_cmark::Tag::Heading { + level: pulldown_cmark::HeadingLevel::H3, + id: None, + classes: vec![], + attrs: vec![], + }; + + /// Returns the number of [`Content`] [`Block`]s in the message. + pub fn len(&self) -> usize { + self.content.len() + } } impl From for Message { @@ -53,27 +102,92 @@ impl From for Message { } } +impl From<(Role, Cow<'static, str>)> for Message { + fn from((role, content): (Role, Cow<'static, str>)) -> Self { + Self { + role, + content: Content::SinglePart(content), + } + } +} + +impl From<(Role, &'static str)> for Message { + fn from((role, content): (Role, &'static str)) -> Self { + Self { + role, + content: Content::SinglePart(Cow::Borrowed(content)), + } + } +} + +#[cfg(feature = "markdown")] +impl crate::markdown::ToMarkdown for Message { + /// Returns an iterator over the text as [`pulldown_cmark::Event`]s using + /// custom [`Options`]. This is [`Content`] markdown plus a heading for the + /// [`Role`]. + /// + /// [`Options`]: crate::markdown::Options + fn markdown_events_custom<'a>( + &'a self, + options: &'a crate::markdown::Options, + ) -> Box> + 'a> { + use pulldown_cmark::Event; + + let content = self.content.markdown_events_custom(options); + let role = match self.content.last() { + Some(Block::ToolResult { is_error, .. }) => { + if *is_error { + "Error" + } else { + "Tool" + } + } + _ => self.role.as_str(), + }; + let heading = [ + Event::Start(Self::HEADING), + Event::Text(role.into()), + Event::End(Self::HEADING.to_end()), + ]; + + Box::new(heading.into_iter().chain(content)) + } +} + +#[cfg(feature = "markdown")] +impl std::fmt::Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::markdown::ToMarkdown; + + self.write_markdown(f) + } +} + /// Content of a [`Message`]. -#[derive(Debug, Serialize, Deserialize, derive_more::From)] +#[derive( + Debug, Serialize, Deserialize, derive_more::From, derive_more::IsVariant, +)] #[serde(rename_all = "snake_case")] #[serde(untagged)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub enum Content { /// Single part text-only content. - SinglePart(String), - /// Multiple content [`Part`]s. - MultiPart(Vec), + SinglePart(Cow<'static, str>), + /// Multiple content [`Block`]s. + MultiPart(Vec), } impl Content { - /// Length of the visible content in bytes, not including metadata like the - /// [`MediaType`] for images, the [`CacheControl`] for text, [`Tool`] - /// calls, results, or any separators or headers. - /// - /// [`Tool`]: crate::tool::Tool + /// Const constructor for static text content. + pub const fn text(text: &'static str) -> Self { + Self::SinglePart(Cow::Borrowed(text)) + } + + /// Returns the number of [`Block`]s in `self`. pub fn len(&self) -> usize { match self { - Self::SinglePart(string) => string.len(), - Self::MultiPart(parts) => parts.iter().map(Part::len).sum(), + Self::SinglePart(_) => 1, + Self::MultiPart(parts) => parts.len(), } } @@ -81,8 +195,118 @@ impl Content { pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Unwrap [`Content::SinglePart`] as a [`Block::Text`]. This will panic if + /// `self` is [`MultiPart`]. + /// + /// [`SinglePart`]: Content::SinglePart + /// [`MultiPart`]: Content::MultiPart + /// + /// # Panics + /// - If the content is [`MultiPart`]. + pub fn unwrap_single_part(self) -> Block { + match self { + #[cfg(feature = "prompt-caching")] + Self::SinglePart(text) => Block::Text { + text, + cache_control: None, + }, + #[cfg(not(feature = "prompt-caching"))] + Self::SinglePart(text) => Block::Text { text }, + Self::MultiPart(_) => { + panic!("Content is MultiPart, not SinglePart"); + } + } + } + + /// Add a [`Block`] to the [`Content`]. If the [`Content`] is a + /// [`SinglePart`], it will be converted to a [`MultiPart`]. Returns the + /// index of the added [`Block`]. + /// + /// [`SinglePart`]: Content::SinglePart + /// [`MultiPart`]: Content::MultiPart + pub fn push

(&mut self, part: P) -> usize + where + P: Into, + { + // If there is a SinglePart message, convert it to a MultiPart message. + if self.is_single_part() { + // the old switcheroo + let mut old = Content::MultiPart(vec![]); + std::mem::swap(self, &mut old); + // This can never loop because we ensure self is a MultiPart which + // will skip this block. + self.push(old.unwrap_single_part()); + } + + if let Content::MultiPart(parts) = self { + parts.push(part.into()); + + parts.len() - 1 + } else { + unreachable!("Content is not MultiPart"); + } + } + + /// Add a cache breakpoint to the final [`Block`]. If the [`Content`] is + /// [`SinglePart`], it will be converted to [`MultiPart`] first. + /// + /// [`SinglePart`]: Content::SinglePart + /// [`MultiPart`]: Content::MultiPart + #[cfg(feature = "prompt-caching")] + pub fn cache(&mut self) { + if self.is_single_part() { + let mut old = Content::MultiPart(vec![]); + std::mem::swap(self, &mut old); + self.push(old.unwrap_single_part()); + } + + if let Content::MultiPart(parts) = self { + if let Some(block) = parts.last_mut() { + block.cache(); + } + } + } + + /// Get the last [`Block`] in the [`Content`]. Returns [`None`] if the + /// [`Content`] is empty. + pub fn last(&self) -> Option<&Block> { + match self { + Self::SinglePart(_) => None, + Self::MultiPart(parts) => parts.last(), + } + } } +#[cfg(feature = "markdown")] +impl crate::markdown::ToMarkdown for Content { + /// Returns an iterator over the text as [`pulldown_cmark::Event`]s using + /// custom [`Options`]. + /// + /// [`Options`]: crate::markdown::Options + #[cfg(feature = "markdown")] + fn markdown_events_custom<'a>( + &'a self, + options: &'a crate::markdown::Options, + ) -> Box> + 'a> { + use pulldown_cmark::Event; + + let it: Box> + 'a> = match self { + Self::SinglePart(string) => { + Box::new(pulldown_cmark::Parser::new(string)) + } + Self::MultiPart(parts) => Box::new( + parts + .iter() + .flat_map(move |part| part.markdown_events_custom(options)), + ), + }; + + it + } +} + +#[cfg(not(feature = "markdown"))] impl std::fmt::Display for Content { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -103,29 +327,53 @@ impl std::fmt::Display for Content { } } +#[cfg(feature = "markdown")] +impl std::fmt::Display for Content { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::markdown::ToMarkdown; + + self.write_markdown(f) + } +} + impl Content { /// Separator for multi-part content. + #[cfg(not(feature = "markdown"))] pub const SEP: &'static str = "\n\n"; } -/// A [`Content`] [`Part`] of a [`Message`], either [`Text`] or [`Image`]. -/// -/// [`Text`]: Part::Text -#[derive(Debug, Serialize, Deserialize, derive_more::Display)] +impl From<&'static str> for Content { + fn from(s: &'static str) -> Self { + Self::SinglePart(s.into()) + } +} + +impl From for Content { + fn from(s: String) -> Self { + Self::SinglePart(s.into()) + } +} + +impl From for Content { + fn from(block: Block) -> Self { + Self::MultiPart(vec![block]) + } +} + +/// A [`Content`] [`Block`] of a [`Message`]. +#[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(not(feature = "markdown"), derive(derive_more::Display))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] -pub enum Part { +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +pub enum Block { /// Text content. #[serde(alias = "text_delta")] - #[display("{}", text)] + #[cfg_attr(not(feature = "markdown"), display("{text}"))] Text { /// The actual text content. - text: String, - /// Use prompt caching. The [`text`] needs to be at least 1024 tokens - /// for Sonnet 3.5 and Opus 3.0 or 2048 for Haiku 3.0 or this will be - /// ignored. - /// - /// [`text`]: Part::Text::text + text: Cow<'static, str>, + /// Use prompt caching. See [`Block::cache`] for more information. #[cfg(feature = "prompt-caching")] #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, @@ -135,87 +383,284 @@ pub enum Part { #[serde(rename = "source")] /// An base64 encoded image. image: Image, + /// Use prompt caching. See [`Block::cache`] for more information. + #[cfg(feature = "prompt-caching")] + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, }, /// [`Tool`] call. This should only be used with the [`Assistant`] role. /// /// [`Assistant`]: Role::Assistant /// [`Tool`]: crate::Tool - #[display("")] + // Default display is to hide this from the user. + #[cfg_attr(not(feature = "markdown"), display(""))] ToolUse { - /// Unique Id for this tool call. - id: String, - /// Name of the tool. - name: String, - /// Input for the tool. - input: serde_json::Value, + /// Tool use input. + #[serde(flatten)] + call: crate::tool::Use, }, /// Result of a [`Tool`] call. This should only be used with the [`User`] /// role. /// /// [`User`]: Role::User /// [`Tool`]: crate::Tool - #[display("")] + #[cfg_attr(not(feature = "markdown"), display(""))] ToolResult { /// Unique Id for this tool call. tool_use_id: String, /// Output of the tool. - content: serde_json::Value, + content: Content, /// Whether the tool call result was an error. is_error: bool, + /// Use prompt caching. See [`Block::cache`] for more information. + #[cfg(feature = "prompt-caching")] + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, }, } -impl Part { - /// Length of text or image data in bytes not including metadata like - /// the [`MediaType`] for images or the [`CacheControl`] for text. - pub fn len(&self) -> usize { +#[cfg(feature = "markdown")] +impl std::fmt::Display for Block { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::markdown::ToMarkdown; + + self.write_markdown(f) + } +} + +impl Block { + /// Const constructor for static text content. + pub const fn new_text(text: &'static str) -> Self { + Self::Text { + text: Cow::Borrowed(text), + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + } + + /// Merge [`Delta`]s into a [`Block`]. The types must be compatible or this + /// will return a [`ContentMismatch`] error. + pub fn merge_deltas(&mut self, deltas: Ds) -> Result<(), DeltaError> + where + Ds: IntoIterator, + { + let mut it = deltas.into_iter(); + + // Get the first delta so we can try to fold the rest into it. + let acc: Delta = match it.next() { + Some(delta) => delta, + // Empty iterator, nothing to merge. + None => return Ok(()), + }; + + // Merge the rest of the deltas into the first one. (there isn't a + // `try_reduce` method yet) + let acc: Delta = it.try_fold(acc, |acc, delta| acc.merge(delta))?; + + // Apply the merged delta to the block. + match (self, acc) { + (Block::Text { text, .. }, Delta::Text { text: delta }) => { + text.to_mut().push_str(&delta); + } + ( + Block::ToolUse { + call: tool::Use { input, .. }, + }, + Delta::Json { partial_json }, + ) => { + *input = serde_json::from_str(&partial_json) + .map_err(|e| e.to_string())?; + } + (this, acc) => { + let variant_name = match this { + Block::Text { .. } => stringify!(Block::Text), + Block::ToolUse { .. } => stringify!(Block::ToolUse), + Block::ToolResult { .. } => stringify!(Block::ToolResult), + Block::Image { .. } => stringify!(Block::Image), + }; + + return Err(ContentMismatch { + from: acc, + to: variant_name, + } + .into()); + } + } + + Ok(()) + } + + /// Create a cache breakpoint at this block. For this to have any effect, + /// the full prefix before this point needs to be at least 1024 tokens for + /// [`Sonnet35`] and [`Opus30`] or 2048 tokens for [`Haiku30`]. + /// + /// Note: The caching feature is in beta, so this is likely to change. + #[cfg(feature = "prompt-caching")] + pub fn cache(&mut self) { + use crate::tool; + match self { - Self::Text { text, .. } => text.len(), - Self::Image { image } => match image { - Image::Base64 { data, .. } => data.len(), - }, - _ => 0, + Self::Text { cache_control, .. } + | Self::Image { cache_control, .. } + | Self::ToolUse { + call: tool::Use { cache_control, .. }, + } + | Self::ToolResult { cache_control, .. } => { + *cache_control = Some(CacheControl::Ephemeral); + } } } - /// Returns true if the part is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 + /// Returns true if the block has a `cache_control` breakpoint. + #[cfg(feature = "prompt-caching")] + pub const fn is_cached(&self) -> bool { + use crate::tool; + + match self { + Self::Text { cache_control, .. } + | Self::Image { cache_control, .. } + | Self::ToolUse { + call: tool::Use { cache_control, .. }, + } + | Self::ToolResult { cache_control, .. } => cache_control.is_some(), + } + } + + /// Returns the [`tool::Use`] if this is a [`Block::ToolUse`]. See also + /// [`response::Message::tool_use`]. + pub fn tool_use(&self) -> Option<&crate::tool::Use> { + match self { + Self::ToolUse { call, .. } => Some(call), + _ => None, + } } } -impl From<&str> for Part { - fn from(text: &str) -> Self { +#[cfg(feature = "markdown")] +impl crate::markdown::ToMarkdown for Block { + /// Returns an iterator over the text as [`pulldown_cmark::Event`]s using + /// custom [`Options`]. + /// + /// [`Options`]: crate::markdown::Options + #[cfg(feature = "markdown")] + fn markdown_events_custom<'a>( + &'a self, + options: &crate::markdown::Options, + ) -> Box> + 'a> { + use pulldown_cmark::{CodeBlockKind, Event, Tag, TagEnd}; + + let it: Box> + 'a> = match self { + Self::Text { text, .. } => { + // We'll parse the inner text as markdown. + Box::new(pulldown_cmark::Parser::new_ext(text, options.inner)) + } + + Block::Image { image, .. } => { + // We use Event::Text for images because they are rendered as + // markdown images with embedded base64 data. + Box::new( + Some(Event::Text(image.to_string().into())).into_iter(), + ) + } + Block::ToolUse { .. } => { + if options.tool_use { + Box::new( + [ + Event::Start(Tag::CodeBlock( + CodeBlockKind::Fenced("json".into()), + )), + Event::Text( + serde_json::to_string(self).unwrap().into(), + ), + Event::End(TagEnd::CodeBlock), + ] + .into_iter(), + ) + } else { + Box::new(std::iter::empty()) + } + } + Block::ToolResult { .. } => { + if options.tool_results { + Box::new( + [ + Event::Start(Tag::CodeBlock( + CodeBlockKind::Fenced("json".into()), + )), + Event::Text( + serde_json::to_string(self).unwrap().into(), + ), + Event::End(TagEnd::CodeBlock), + ] + .into_iter(), + ) + } else { + Box::new(std::iter::empty()) + } + } + }; + + it + } +} + +impl From<&'static str> for Block { + fn from(text: &'static str) -> Self { Self::Text { - text: text.to_string(), + text: text.into(), #[cfg(feature = "prompt-caching")] cache_control: None, } } } -impl From for Part { +impl From for Block { fn from(text: String) -> Self { Self::Text { - text, + text: text.into(), #[cfg(feature = "prompt-caching")] cache_control: None, } } } -impl From for Part { +impl From for Block { fn from(image: Image) -> Self { - Self::Image { image } + Self::Image { + image, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + } +} + +#[cfg(feature = "png")] +impl From for Block { + fn from(image: image::RgbaImage) -> Self { + Image::encode(MediaType::Png, image) + // Unwrap can never panic unless the PNG encoding fails. + .unwrap_or_else(|e| { + eprintln!("Error encoding image: {}", e); + Image::from_parts(MediaType::Png, String::new()) + }) + .into() + } +} + +#[cfg(feature = "png")] +impl From for Block { + fn from(image: image::DynamicImage) -> Self { + image.to_rgba8().into() } } /// Cache control for prompt caching. #[cfg(feature = "prompt-caching")] -#[derive(Default, Debug, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[serde(tag = "type")] pub enum CacheControl { - /// Ephemeral + /// Caches for 5 minutes. #[default] Ephemeral, } @@ -224,6 +669,7 @@ pub enum CacheControl { /// /// [`MultiPart`]: Content::MultiPart #[derive(Debug, Serialize, Deserialize, derive_more::Display)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] pub enum Image { @@ -316,6 +762,7 @@ impl TryInto for Image { /// Encoding format for [`Image`]s. #[derive(Clone, Copy, Debug, Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] #[allow(missing_docs)] pub enum MediaType { @@ -412,6 +859,242 @@ mod tests { fn deserialize_message_single() { let message: Message = serde_json::from_str(MESSAGE_JSON_SINGLE).unwrap(); + // FIXME: This is really testing the Display impl. There should be a + // separate test for that. assert_eq!(message.to_string(), "### User\n\nHello, world"); } + + #[test] + #[cfg(feature = "markdown")] + fn test_merge_deltas() { + use crate::markdown::ToMarkdown; + + let mut block: Block = "Hello, world!".into(); + + // this is allowed + block.merge_deltas([]).unwrap(); + + let deltas = [ + Delta::Text { + text: ", how are you?".to_string(), + }, + Delta::Text { + text: " I'm fine.".to_string(), + }, + ]; + + block.merge_deltas(deltas).unwrap(); + + assert_eq!(block.to_string(), "Hello, world!, how are you? I'm fine."); + + // with tool use + let mut block: Block = Block::ToolUse { + call: tool::Use { + id: "tool_123".into(), + name: "tool".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }, + }; + + // partial json to apply to the input portion + let deltas = [Delta::Json { + partial_json: r#"{"key": "value"}"#.to_string(), + }]; + + block.merge_deltas(deltas).unwrap(); + + // by default tool use is hidden + let opts = crate::markdown::Options::default().with_tool_use(); + + let markdown = block.markdown_custom(&opts); + + assert_eq!( + markdown.as_ref(), + "\n````json\n{\"type\":\"tool_use\",\"id\":\"tool_123\",\"name\":\"tool\",\"input\":{\"key\":\"value\"}}\n````" + ); + + // content mismatch + let deltas = [Delta::Json { + partial_json: "blabla".to_string(), + }]; + let mut block = Block::Text { + text: "Hello, world!".into(), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let err = block.merge_deltas(deltas).unwrap_err(); + assert_eq!( + err.to_string(), + "Cannot apply delta because: `Delta::Json { partial_json: \"blabla\" }` canot be applied to `Block::Text`." + ); + } + + #[test] + fn test_message_len() { + let mut message = Message { + role: Role::User, + content: Content::SinglePart("Hello, world!".into()), + }; + + assert_eq!(message.len(), 1); + + message.content.push("How are you?"); + + assert_eq!(message.len(), 2); + } + + #[test] + fn test_from_response_message() { + let response = response::Message { + message: Message { + role: Role::User, + content: Content::text("Hello, world!"), + }, + id: "msg_123".into(), + model: crate::Model::Sonnet35, + stop_reason: None, + stop_sequence: None, + usage: Default::default(), + }; + + let message: Message = response.into(); + + assert_eq!(message.to_string(), "### User\n\nHello, world!"); + } + + #[test] + fn test_from_role_cow() { + let text: Cow<'static, str> = "Hello, world!".into(); + let message: Message = (Role::User, text).into(); + + assert_eq!(message.to_string(), "### User\n\nHello, world!"); + } + + #[test] + fn test_from_role_str() { + let message: Message = (Role::User, "Hello, world!").into(); + + assert_eq!(message.to_string(), "### User\n\nHello, world!"); + } + + #[test] + fn test_content_is_empty() { + let mut content = Content::SinglePart("Hello, world!".into()); + assert!(!content.is_empty()); + + content = Content::MultiPart(vec![]); + assert!(content.is_empty()); + } + + #[test] + fn tests_content_unwrap_single_part() { + let content = Content::SinglePart("Hello, world!".into()); + assert_eq!(content.unwrap_single_part().to_string(), "Hello, world!"); + } + + #[test] + #[should_panic] + fn test_content_unwrap_single_part_panics() { + let content = Content::MultiPart(vec![]); + content.unwrap_single_part(); + } + + #[test] + fn test_content_from_string() { + let content: Content = "Hello, world!".to_string().into(); + assert_eq!(content.to_string(), "Hello, world!"); + } + + #[test] + fn test_content_from_block() { + let content: Content = Block::new_text("Hello, world!").into(); + assert_eq!(content.to_string(), "Hello, world!"); + } + + #[test] + fn test_merge_deltas_error() { + let mut block: Block = "Hello, world!".into(); + + let deltas = [Delta::Json { + partial_json: "blabla".to_string(), + }]; + + let err = block.merge_deltas(deltas).unwrap_err(); + + assert!(matches!(err, DeltaError::ContentMismatch { .. })); + } + + #[test] + #[cfg(feature = "markdown")] + fn test_message_markdown() { + use crate::markdown::ToMarkdown; + + // test user heading, single part + let message = Message { + role: Role::User, + content: Content::SinglePart("Hello, world!".into()), + }; + + let opts = crate::markdown::Options::default() + .with_tool_use() + .with_tool_results(); + + assert_eq!( + message.markdown_custom(&opts).to_string(), + "### User\n\nHello, world!" + ); + + // test assistant heading, multi part + let message = Message { + role: Role::Assistant, + content: Content::MultiPart(vec![ + "Hello, world!".into(), + "How are you?".into(), + ]), + }; + + assert_eq!( + message.markdown_custom(&opts).to_string(), + "### Assistant\n\nHello, world!\n\nHow are you?" + ); + + // Test tool result (success) + let message = Message { + role: Role::User, + content: Block::ToolResult { + tool_use_id: "tool_123".into(), + content: Content::SinglePart("Hello, world!".into()), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(), + }; + + assert_eq!( + message.markdown_custom(&opts).to_string(), + "### Tool\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"tool_123\",\"content\":\"Hello, world!\",\"is_error\":false}\n````" + ); + + // Test tool result (error) + let message = Message { + role: Role::User, + content: Block::ToolResult { + tool_use_id: "tool_123".into(), + content: Content::SinglePart("Hello, world!".into()), + is_error: true, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(), + }; + + assert_eq!( + message.markdown_custom(&opts).to_string(), + "### Error\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"tool_123\",\"content\":\"Hello, world!\",\"is_error\":true}\n````" + ); + } } diff --git a/src/response.rs b/src/response.rs index 47eedf0..f500f5d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -141,7 +141,7 @@ mod tests { #[test] fn deserialize_response_message() { let message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); - assert_eq!(message.message.content.len(), 22); + assert_eq!(message.message.content.len(), 1); assert_eq!(message.id, "msg_013Zva2CMHLNnXjNJJKqJ2EF"); assert_eq!(message.model, crate::Model::Sonnet35); assert!(matches!(message.stop_reason, Some(StopReason::EndTurn))); diff --git a/src/response/message.rs b/src/response/message.rs index 725fd95..b87907c 100644 --- a/src/response/message.rs +++ b/src/response/message.rs @@ -1,8 +1,9 @@ -use crate::{request, Model}; +use crate::{request, stream::MessageDelta, Model}; use serde::{Deserialize, Serialize}; /// A [`request::Message`] with additional response metadata. #[derive(Debug, Serialize, Deserialize, derive_more::Display)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[display("{}", message)] pub struct Message { /// Unique `id` for the message. @@ -23,8 +24,36 @@ pub struct Message { pub usage: Usage, } +impl Message { + /// Apply a [`MessageDelta`] with metadata to the message. + pub fn apply_delta(&mut self, delta: MessageDelta) { + self.stop_reason = delta.stop_reason; + self.stop_sequence = delta.stop_sequence; + if let Some(usage) = delta.usage { + self.usage = usage; + } + } + + /// Get the [`tool::Use`] from the message if the [`StopReason`] was + /// [`StopReason::ToolUse`] and the final message [`Content`] [`Block`] is + /// [`ToolUse`]. + /// + /// [`Content`]: crate::request::message::Content + /// [`Block`]: crate::request::message::Block + /// [`tool::Use`]: crate::tool::Use + /// [`ToolUse`]: crate::request::message::Block::ToolUse + pub fn tool_use(&self) -> Option<&crate::tool::Use> { + if !matches!(self.stop_reason, Some(StopReason::ToolUse)) { + return None; + } + + self.message.content.last()?.tool_use() + } +} + /// Reason the model stopped generating tokens. #[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] pub enum StopReason { /// The model reached a natural stopping point. @@ -39,7 +68,8 @@ pub enum StopReason { /// Usage statistics from the API. This is used in multiple contexts, not just /// for messages. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub struct Usage { /// Number of input tokens used. pub input_tokens: u64, diff --git a/src/stream.rs b/src/stream.rs index 0f07764..ef37073 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,8 +4,13 @@ use futures::StreamExt; use serde::{Deserialize, Serialize}; use std::pin::Pin; +#[allow(unused_imports)] // `Content`, `request` Used in docs. use crate::{ client::AnthropicError, + request::{ + self, + message::{Block, Content}, + }, response::{self, StopReason, Usage}, }; @@ -17,33 +22,33 @@ use crate::{ pub enum Event { /// Periodic ping. Ping, - /// [`response::Message`] with empty content. The deltas arrive in - /// [`ContentBlock`]s. + /// [`response::Message`] with empty content. [`MessageDelta`] and + /// [`Content`] [`Delta`]s must be applied to this message. MessageStart { /// The message. message: response::Message, }, - /// Content block with empty content. + /// [`Content`] [`Block`] with empty content. ContentBlockStart { - /// Index of the content block. + /// Index of the [`Content`] [`Block`] in [`request::Message::content`]. index: usize, /// Empty content block. - content_block: ContentBlock, + content_block: Block, }, /// Content block delta. ContentBlockDelta { - /// Index of the content block. + /// Index of the [`Content`] [`Block`] in [`request::Message::content`]. index: usize, /// Delta to apply to the content block. - delta: ContentBlock, + delta: Delta, }, /// Content block end. ContentBlockStop { - /// Index of the content block. + /// Index of the [`Content`] [`Block`] in [`request::Message::content`]. index: usize, }, - /// Message delta. Confusingly this does not contain message content rather - /// metadata about the message in progress. + /// [`MessageDelta`]. Contains metadata, not [`Content`] [`Delta`]s. Apply + /// to the [`response::Message`]. MessageDelta { /// Delta to apply to the [`response::Message`]. delta: MessageDelta, @@ -66,80 +71,100 @@ enum ApiResult { Error { error: AnthropicError }, } -/// A content block or delta. This can be [`Text`], [`Json`], or [`Tool`] use. +/// [`Text`] or [`Json`] to be applied to a [`Block::Text`] or +/// [`Block::ToolUse`] [`Content`] [`Block`]. /// -/// [`Text`]: ContentBlock::Text -/// [`Json`]: ContentBlock::Json -/// [`Tool`]: ContentBlock::Tool +/// [`Text`]: Delta::Text +/// [`Json`]: Delta::Json #[derive(Debug, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "snake_case", tag = "type")] -pub enum ContentBlock { - /// Text content. +pub enum Delta { + /// Text delta for a [`Text`] [`Content`] [`Block`]. + /// + /// [`Text`]: Block::Text #[serde(alias = "text_delta")] Text { /// The text content. text: String, }, - /// JSON delta. + /// JSON delta for the input field of a [`ToolUse`] [`Content`] [`Block`]. + /// + /// [`ToolUse`]: Block::ToolUse #[serde(rename = "input_json_delta")] Json { /// The JSON delta. partial_json: String, }, - /// Tool use. - #[serde(rename = "tool_use")] - Tool { - /// ID of the request. - id: String, - /// Name of the tool. - name: String, - /// Input to the tool. - input: serde_json::Value, - }, } -/// Error when applying a [`ContentBlock`] delta to a target [`ContentBlock`]. +/// Error when applying a [`Delta`] to a [`Content`] [`Block`] and the types do +/// not match. #[derive(Serialize, thiserror::Error, Debug)] -#[error("Cannot apply delta {from:?} to {to:?}.")] -pub struct ContentMismatch<'a> { +#[error("`Delta::{from:?}` canot be applied to `{to}`.")] +pub struct ContentMismatch { /// The content block that failed to apply. - from: ContentBlock, - /// The target content block. - to: &'a ContentBlock, + pub from: Delta, + /// The target [`Content`]. + pub to: &'static str, } -impl ContentBlock { - /// Apply a [`ContentBlock`] delta to self. - pub fn append( - &mut self, - delta: ContentBlock, - ) -> Result<(), ContentMismatch> { - match (self, delta) { - ( - ContentBlock::Text { text }, - ContentBlock::Text { text: delta }, - ) => { +/// Error when applying a [`Delta`] to a [`Content`] [`Block`] and the index is +/// out of bounds. +#[derive(Serialize, thiserror::Error, Debug)] +#[error("Index {index} out of bounds. Max index is {max}.")] +pub struct OutOfBounds { + /// The index that was out of bounds. + pub index: usize, + /// The maximum index. + pub max: usize, +} + +/// Error when applying a [`Delta`]. +#[derive(Serialize, thiserror::Error, Debug, derive_more::From)] +#[allow(missing_docs)] +pub enum DeltaError { + #[error("Cannot apply delta because: {error}")] + ContentMismatch { error: ContentMismatch }, + #[error("Cannot apply delta because: {error}")] + OutOfBounds { error: OutOfBounds }, + #[error( + "Cannot apply delta because deserialization failed because: {error}" + )] + Parse { error: String }, +} + +impl Delta { + /// Merge another [`Delta`] onto the end of `self`. + pub fn merge(mut self, delta: Delta) -> Result { + match (&mut self, delta) { + (Delta::Text { text }, Delta::Text { text: delta }) => { text.push_str(&delta); } ( - ContentBlock::Json { partial_json }, - ContentBlock::Json { + Delta::Json { partial_json }, + Delta::Json { partial_json: delta, }, ) => { partial_json.push_str(&delta); } (to, from) => { - return Err(ContentMismatch { from, to }); + return Err(ContentMismatch { + from, + to: match to { + Delta::Text { .. } => stringify!(Delta::Text), + Delta::Json { .. } => stringify!(Delta::Json), + }, + }); } } - Ok(()) + Ok(self) } } /// Metadata about a message in progress. This does not contain actual text -/// deltas. That's the [`ContentBlock`] in [`Event::ContentBlockDelta`]. +/// deltas. That's the [`Delta`] in [`Event::ContentBlockDelta`]. #[derive(Debug, Serialize, Deserialize)] pub struct MessageDelta { /// Stop reason. @@ -245,9 +270,7 @@ impl Stream { /// Filter out everything but [`Event::ContentBlockDelta`]. This can include /// text, JSON, and tool use. - pub fn deltas( - self, - ) -> impl futures::Stream> { + pub fn deltas(self) -> impl futures::Stream> { self.inner.filter_map(|result| async move { match result { Ok(Event::ContentBlockDelta { delta, .. }) => Some(Ok(delta)), @@ -260,7 +283,7 @@ impl Stream { pub fn text(self) -> impl futures::Stream> { self.deltas().filter_map(|result| async move { match result { - Ok(ContentBlock::Text { text }) => Some(Ok(text)), + Ok(Delta::Text { text }) => Some(Ok(text)), _ => None, } }) @@ -297,10 +320,23 @@ mod tests { content_block, } => { assert_eq!(index, 0); - assert_eq!( - content_block, - ContentBlock::Text { text: "".into() } - ); + #[cfg(feature = "prompt-caching")] + if let Block::Text { + text, + cache_control, + } = content_block + { + assert_eq!(text, ""); + assert!(cache_control.is_none()); + } else { + panic!("Unexpected content block: {:?}", content_block); + } + #[cfg(not(feature = "prompt-caching"))] + if let Block::Text { text } = content_block { + assert_eq!(text, ""); + } else { + panic!("Unexpected content block: {:?}", content_block); + } } _ => panic!("Unexpected event: {:?}", event), } @@ -314,7 +350,7 @@ mod tests { assert_eq!(index, 0); assert_eq!( delta, - ContentBlock::Text { + Delta::Text { text: "Certainly! I".into() } ); @@ -322,4 +358,24 @@ mod tests { _ => panic!("Unexpected event: {:?}", event), } } + + #[test] + fn test_content_block_delta_merge() { + let delta = Delta::Text { + text: "Certainly! I".into(), + } + .merge(Delta::Text { + text: " can".into(), + }) + .unwrap() + .merge(Delta::Text { text: " do".into() }) + .unwrap(); + + assert_eq!( + delta, + Delta::Text { + text: "Certainly! I can do".into() + } + ); + } } diff --git a/src/tool.rs b/src/tool.rs index 44b3586..9a5dac1 100644 --- a/src/tool.rs +++ b/src/tool.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; /// [`request::Message`]: crate::request::Message #[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] pub enum Choice { /// Model chooses which tool to use, or no tool at all. Auto, @@ -21,6 +22,7 @@ pub enum Choice { /// A tool a model can use while completing a [`request::Message`]. /// /// [`request::Message`]: crate::request::Message +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[derive(Serialize, Deserialize)] pub struct Tool { /// Name of the tool. @@ -34,4 +36,105 @@ pub struct Tool { /// [tool use guide]: /// [JSON Schema]: pub input_schema: serde_json::Value, + /// Set a cache breakpoint at this tool. See [`Request::cache`] notes for + /// more information. + /// + /// [`Request::cache`]: crate::request::Request::cache + #[cfg(feature = "prompt-caching")] + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +impl Tool { + /// Create a cache breakpoint at this [`Tool`] by setting [`cache_control`] + /// to [`Ephemeral`] See [`Request::cache`] for more information. + /// + /// [`cache_control`]: Self::cache_control + /// [`Ephemeral`]: crate::request::message::CacheControl::Ephemeral + /// [`Request::cache`]: crate::request::Request::cache + #[cfg(feature = "prompt-caching")] + pub fn cache(&mut self) -> &mut Self { + self.cache_control = + Some(crate::request::message::CacheControl::Ephemeral); + self + } + + /// Returns true if the [`Tool`] has a cache breakpoint set (if + /// `cache_control` is [`Some`]). + #[cfg(feature = "prompt-caching")] + pub fn is_cached(&self) -> bool { + self.cache_control.is_some() + } +} + +impl TryFrom for Tool { + type Error = serde_json::Error; + + fn try_from(value: serde_json::Value) -> Result { + serde_json::from_value(value) + } +} + +/// A tool call made by the model. This should be handled and a response sent +/// back in a [`Block::ToolRes`] +#[cfg_attr( + not(feature = "markdown"), + derive(derive_more::Display), + display("\n````json\n{}\n````\n", serde_json::to_string_pretty(self).unwrap()) +)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Use { + /// Unique Id for this tool call. + pub id: String, + /// Name of the tool. + pub name: String, + /// Input for the tool. + pub input: serde_json::Value, + /// Use prompt caching. See [`Block::cache`] for more information. + #[cfg(feature = "prompt-caching")] + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +impl TryFrom for Use { + type Error = serde_json::Error; + + fn try_from(value: serde_json::Value) -> Result { + serde_json::from_value(value) + } +} + +#[cfg(feature = "markdown")] +impl crate::markdown::ToMarkdown for Use { + fn markdown_events_custom<'a>( + &'a self, + options: &'a crate::markdown::Options, + ) -> Box> + 'a> { + use pulldown_cmark::{CodeBlockKind, Event, Tag, TagEnd}; + + if options.tool_use { + Box::new( + [ + Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced( + "json".into(), + ))), + Event::Text(serde_json::to_string(self).unwrap().into()), + Event::End(TagEnd::CodeBlock), + ] + .into_iter(), + ) + } else { + Box::new(std::iter::empty()) + } + } +} + +#[cfg(feature = "markdown")] +impl std::fmt::Display for Use { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::markdown::ToMarkdown; + + self.write_markdown(f) + } } From dbbb1b4b03d43d60609bb7dbed3c830129d537b2 Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Wed, 11 Sep 2024 22:26:57 -0700 Subject: [PATCH 2/7] Fix CI Coverage upload Coverage should not be uploaded on PR. --- .github/workflows/tests.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7504096..147da98 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -56,16 +56,17 @@ jobs: - name: Run tests run: cargo test --all-features --verbose + # This should only happen on push to main. PRs should not upload coverage. - name: Install tarpaulin - if: matrix.os == 'ubuntu-latest' + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' run: cargo install cargo-tarpaulin - name: Run tarpaulin - if: matrix.os == 'ubuntu-latest' + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' run: cargo tarpaulin --out Xml --all-features - name: Upload coverage to Codecov - if: matrix.os == 'ubuntu-latest' + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' uses: codecov/codecov-action@v2 with: files: ./cobertura.xml From f50eef550aa09ce7a8e3612930a8cc7e416470b9 Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Wed, 11 Sep 2024 23:25:03 -0700 Subject: [PATCH 3/7] improve `Request` markdown - "System" now has a H3 header like the rest. --- src/request.rs | 175 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 170 insertions(+), 5 deletions(-) diff --git a/src/request.rs b/src/request.rs index 6d52ea9..74c399b 100644 --- a/src/request.rs +++ b/src/request.rs @@ -511,17 +511,42 @@ impl Request { #[cfg(feature = "markdown")] impl crate::markdown::ToMarkdown for Request { + /// Format the [`Request`] chat as markdown in OpenAI style. H3 headings are + /// used for "System", "Tool", "User", and "Assistant" messages even though + /// technically there are only [`User`] and [`Assistant`] [`Role`]s. + /// + /// [`User`]: message::Role::User + /// [`Assistant`]: message::Role::Assistant + /// [`Role`]: message::Role fn markdown_events_custom<'a>( &'a self, options: &'a crate::markdown::Options, ) -> Box> + 'a> { - let system = if let Some(system) = self - .system - .as_ref() - .map(|s| s.markdown_events_custom(options)) + use pulldown_cmark::{Event, HeadingLevel, Tag, TagEnd}; + + // TODO: Add the title if there is metadata for it. Also add a metadata + // option to Options to include arbitrary metadata. In my use case I am + // feeding the markdown to another model that will make use of this data + // so it does need to be included. + + let system: Box>> = if let Some(system) = + self.system + .as_ref() + .map(|s| s.markdown_events_custom(options)) { if options.system { - system + let header = [ + Event::Start(Tag::Heading { + level: HeadingLevel::H3, + id: None, + classes: vec![], + attrs: vec![], + }), + Event::Text("System".into()), + Event::End(TagEnd::Heading(HeadingLevel::H3)), + ]; + + Box::new(header.into_iter().chain(system)) } else { Box::new(std::iter::empty()) } @@ -809,5 +834,145 @@ mod tests { "Ping a server. Part deux." ); assert_eq!(request.tools.as_ref().unwrap()[0].input_schema, schema); + + // Test with a fallible tool. This should fail. + + let invalid = json!({ + "potato": "ping3", + "description": "Ping a server. Part trois.", + "input_schema": { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "The host to ping." + } + }, + "required": ["host"] + } + }); + let err = Request::default().try_add_tool(invalid.clone()); + if let Err(e) = err { + assert_eq!(e.to_string(), "missing field `name`"); + } else { + panic!("Expected an error."); + } + + let err = Request::default().try_tools([invalid]); + if let Err(e) = err { + assert_eq!(e.to_string(), "missing field `name`"); + } else { + panic!("Expected an error."); + } + } + + #[test] + fn test_temperature() { + let request = Request::default().temperature(Some(0.5)); + assert_eq!(request.temperature, Some(0.5)); + } + + #[test] + #[allow(unused_variables)] // because the compiler is silly sometimes + fn test_tool_choice() { + let choice = tool::Choice::Any; + let request = Request::default().tool_choice(choice); + assert!(matches!(request.tool_choice, Some(choice))); + } + + #[test] + fn test_top_k() { + let request = + Request::default().top_k(Some(NonZeroU16::new(5).unwrap())); + assert_eq!(request.top_k, Some(NonZeroU16::new(5).unwrap())); + } + + #[test] + fn test_top_p() { + let request = Request::default().top_p(Some(0.5)); + assert_eq!(request.top_p, Some(0.5)); + } + + #[test] + #[cfg(feature = "markdown")] + fn test_markdown() { + use crate::markdown::{Markdown, ToMarkdown}; + use message::Block; + + let request = Request::default() + .tools([Tool { + name: "ping".into(), + description: "Ping a server.".into(), + input_schema: json!({ + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "The host to ping." + } + }, + "required": ["host"] + }), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }]) + .system("You are a very succinct assistant.") + .messages([ + Message { + role: Role::User, + content: Content::text("Hello"), + }, + Message { + role: Role::Assistant, + content: Content::text("Hi"), + }, + Message { + role: Role::User, + content: Content::text("Call a tool."), + }, + Message { + role: Role::Assistant, + content: Block::ToolUse { + call: tool::Use { + id: "abc123".into(), + name: "ping".into(), + input: json!({ + "host": "example.com" + }), + cache_control: None, + }, + } + .into(), + }, + Message { + role: Role::User, + content: Block::ToolResult { + tool_use_id: "abc123".into(), + content: "Pinging example.com.".into(), + is_error: false, + cache_control: None, + } + .into(), + }, + Message { + role: Role::Assistant, + content: Content::text("Done."), + }, + ]); + + let opts = crate::markdown::Options::default() + .with_system() + .with_tool_use() + .with_tool_results(); + + let markdown: Markdown = request.markdown_custom(&opts); + + // OpenAI format. Anthropic doesn't have a "system" or "tool" role but + // we generate markdown like this because it's easier to read. The user + // does not submit a tool result, so it's confusing if the header is + // "User". + let expected = "### System\n\nYou are a very succinct assistant.\n\n### User\n\nHello\n\n### Assistant\n\nHi\n\n### User\n\nCall a tool.\n\n### Assistant\n\n````json\n{\"type\":\"tool_use\",\"id\":\"abc123\",\"name\":\"ping\",\"input\":{\"host\":\"example.com\"}}\n````\n\n### Tool\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"abc123\",\"content\":\"Pinging example.com.\",\"is_error\":false}\n````\n\n### Assistant\n\nDone."; + + assert_eq!(markdown.as_ref(), expected); } } From 77cc5a01554476123d605a4fae80a2334c58bd75 Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Thu, 12 Sep 2024 15:08:16 -0700 Subject: [PATCH 4/7] EOD WIP - More conversion shortcuts making example code much shorter - More coverage, even of dumb things like accessor methods. - Verbose Options to print System and Tool use as well as messages. - More Cow<'static, str> to avoid copies and allow `const` `Message`s and so on. Eventually we might add a lifetime generic so this could also be used for zero-copy, deserialization but the cognitive overhead is not worth it yet. - Update `README.md` code --- README.md | 37 +++------- examples/neologism.rs | 40 +++-------- examples/strawberry.rs | 44 ++++-------- src/markdown.rs | 21 ++++++ src/request.rs | 42 +++++------- src/request/message.rs | 146 +++++++++++++++++++++++++++++----------- src/response.rs | 136 +++++++++++++++++++++++++++++-------- src/response/message.rs | 84 ++++++++++++++++++++++- src/stream.rs | 28 +++++++- src/tool.rs | 34 +++++++++- 10 files changed, 429 insertions(+), 183 deletions(-) diff --git a/README.md b/README.md index 1717af4..37de5bc 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,13 @@ Is an unofficial simple, ergonomic, client for the Anthropic Messages API. ### Streaming ```rust -// Create a client. `key` will be consumed, zeroized, and stored securely. +// Create a client. The key is encrypted in memory and source string is zeroed. +// When requests are made, the key header is marked as sensitive. let client = Client::new(key)?; -// Request a stream of events or errors. `json!` can be used, a `Request`, or a -// combination of strings and concrete types like `Model`. All Client request -// methods accept anything serializable for maximum flexibility. +// Request a stream of events or errors. `json!` can be used, the `Request` +// builder pattern (shown in the `Single Message` example below), or anything +// serializable. let stream = client // Forces `stream=true` in the request. .stream(json!({ @@ -47,31 +48,13 @@ let content: String = stream ### Single Message ```rust -// Create a client. `key` will be consumed and zeroized. let client = Client::new(key)?; -// Request a single message. The parameters are the same as the streaming -// example above. If a value is `None` it will be omitted from the request. -// This is less flexible than json! but some may prefer it. A Builder pattern -// is not yet available but is planned to reduce the verbosity. +// Many common usage patterns are supported out of the box for building +// `Request`s, such as messages from an iterable of tuples of `Role` and +// `String`. let message = client - .message(Request { - model: Model::Sonnet35, - messages: vec![Message { - role: Role::User, - content: args.prompt.into(), - }], - max_tokens: 1000.try_into().unwrap(), - metadata: serde_json::Value::Null, - stop_sequences: None, - stream: None, - system: None, - temperature: Some(1.0), - tool_choice: None, - tools: None, - top_k: None, - top_p: None, - }) + .message(Request::default().messages([(Role::User, args.prompt)])) .await?; println!("{}", message); @@ -80,8 +63,10 @@ println!("{}", message); ## Features - [x] Async but does not _directly_ depend on tokio +- [x] Tool use, - [x] Streaming responses - [x] Message responses +- [x] Zero-copy where possible - [x] Image support with or without the `image` crate - [x] Markdown formatting of messages, including images - [x] Prompt caching support diff --git a/examples/neologism.rs b/examples/neologism.rs index 0c7d41c..419e5b1 100644 --- a/examples/neologism.rs +++ b/examples/neologism.rs @@ -1,15 +1,11 @@ //! See `source` for an example of [`Client::message`] using the "neologism //! creator" prompt. For a streaming example, see the `website_wizard` example. -// Note: This example uses blocking calls for simplicity such as `print` -// `read_to_string`, `stdin().lock()`, and `write`. In a real application, these -// should usually be replaced with async alternatives. - +// Note: This example uses blocking calls for simplicity such as `println!()` +// and `stdin().lock()`. In a real application, these should *usually* be +// replaced with async alternatives. use clap::Parser; -use misanthropic::{ - request::{message::Role, Message}, - Client, Model, Request, -}; +use misanthropic::{request::message::Role, Client, Request}; use std::io::{stdin, BufRead}; /// Invent new words and provide their definitions based on user-provided @@ -40,30 +36,16 @@ async fn main() -> Result<(), Box> { println!("Enter your API key:"); let key = stdin().lock().lines().next().unwrap()?; - // Create a client. `key` will be consumed and zeroized. + // Create a client. The key is encrypted in memory and source string is + // zeroed. When requests are made, the key header is marked as sensitive. let client = Client::new(key)?; - // Request a completion. `json!` can be used, `Request` or a combination of - // strings and types like `Model`. Client request methods accept anything - // serializable for maximum flexibility. + // Request a completion. `json!` can be used, the `Request` builder pattern, + // or anything serializable. Many common usage patterns are supported out of + // the box for building `Request`s, such as messages from a list of tuples + // of `Role` and `String`. let message = client - .message(Request { - model: Model::Sonnet35, - messages: vec![Message { - role: Role::User, - content: args.prompt.into(), - }], - max_tokens: 1000.try_into().unwrap(), - metadata: serde_json::Map::new(), - stop_sequences: None, - stream: None, - system: None, - temperature: Some(1.0), - tool_choice: None, - tools: None, - top_k: None, - top_p: None, - }) + .message(Request::default().messages([(Role::User, args.prompt)])) .await?; println!("{}", message); diff --git a/examples/strawberry.rs b/examples/strawberry.rs index 99c6f2c..4dbd552 100644 --- a/examples/strawberry.rs +++ b/examples/strawberry.rs @@ -8,11 +8,8 @@ use std::io::BufRead; use clap::Parser; use misanthropic::{ json, - markdown::{self, ToMarkdown}, - request::{ - message::{Block, Role}, - Message, - }, + markdown::ToMarkdown, + request::{message::Role, Message}, response, tool, Client, Request, Tool, }; @@ -66,18 +63,16 @@ pub fn handle_tool_call(call: &tool::Use) -> Result { ) { let count = count_letters(letter, string.into()); - Ok(Message { - role: Role::User, - content: Block::ToolResult { - tool_use_id: call.id.clone(), - content: count.to_string().into(), - is_error: false, - #[cfg(feature = "prompt-caching")] - cache_control: None, - } - // A Content Block is always convertable into Content. - .into(), - }) + Ok(tool::Result { + tool_use_id: call.id.clone(), + content: count.to_string().into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + // A `tool::Result` is always convertable to a `Message`. The `Role` is + // always `User` and the `Content` is always a `Block::ToolResult`. + .into()) } else { // Optionally, we could always return a Message and inform the Assistant // that they called the tool incorrectly so they can try again. @@ -123,10 +118,7 @@ async fn main() -> Result<(), Box> { // Inform the assistant about their limitations. }).system("You are a helpful assistant. You cannot count letters in a word by yourself because you see in tokens, not letters. Use the `count_letters` tool to overcome this limitation.") // Add user input. - .add_message(Message { - role: Role::User, - content: args.prompt.into(), - }); + .add_message((Role::User, args.prompt)); // Generate the next message in the chat. let message = client.message(&chat).await?; @@ -152,15 +144,7 @@ async fn main() -> Result<(), Box> { // default display also renders markdown, but without system prompt and // tool use information. chat.messages.push(message.into()); - println!( - "{}", - chat.markdown_custom( - &markdown::Options::default() - .with_system() - .with_tool_use() - .with_tool_results() - ) - ); + println!("{}", chat.markdown_verbose()); } else { // Just print the message content. The response `Message` contains the // `request::Message` with a `Role` and `Content`. The message can also diff --git a/src/markdown.rs b/src/markdown.rs index 70254fc..21de394 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -10,9 +10,20 @@ pub const DEFAULT_OPTIONS: Options = Options { system: false, }; +/// Verbose [`Options`] +pub const VERBOSE_OPTIONS: Options = Options { + inner: pulldown_cmark::Options::empty(), + tool_use: true, + tool_results: true, + system: true, +}; + /// A static reference to the default [`Options`]. pub static DEFAULT_OPTIONS_REF: &'static Options = &DEFAULT_OPTIONS; +/// A static reference to the verbose [`Options`]. +pub static VERBOSE_OPTIONS_REF: &'static Options = &VERBOSE_OPTIONS; + mod serde_inner { use super::*; @@ -56,6 +67,11 @@ pub struct Options { } impl Options { + /// Maximum verbosity + pub fn verbose() -> Self { + VERBOSE_OPTIONS + } + /// Set [`tool_use`] to true /// /// [`tool_use`]: Options::tool_use @@ -162,6 +178,11 @@ pub trait ToMarkdown { self.markdown_events_custom(options).into() } + /// Render the type to a [`Markdown`] string with maximum verbosity. + fn markdown_verbose(&self) -> Markdown { + self.markdown_custom(VERBOSE_OPTIONS_REF) + } + /// Render the markdown to a type implementing [`std::fmt::Write`] with /// [`DEFAULT_OPTIONS`]. fn write_markdown( diff --git a/src/request.rs b/src/request.rs index 74c399b..952ebbc 100644 --- a/src/request.rs +++ b/src/request.rs @@ -897,7 +897,6 @@ mod tests { #[cfg(feature = "markdown")] fn test_markdown() { use crate::markdown::{Markdown, ToMarkdown}; - use message::Block; let request = Request::default() .tools([Tool { @@ -930,30 +929,23 @@ mod tests { role: Role::User, content: Content::text("Call a tool."), }, - Message { - role: Role::Assistant, - content: Block::ToolUse { - call: tool::Use { - id: "abc123".into(), - name: "ping".into(), - input: json!({ - "host": "example.com" - }), - cache_control: None, - }, - } - .into(), - }, - Message { - role: Role::User, - content: Block::ToolResult { - tool_use_id: "abc123".into(), - content: "Pinging example.com.".into(), - is_error: false, - cache_control: None, - } - .into(), - }, + tool::Use { + id: "abc123".into(), + name: "ping".into(), + input: json!({ + "host": "example.com" + }), + cache_control: None, + } + .into(), + tool::Result { + tool_use_id: "abc123".into(), + content: "Pinging example.com.".into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(), Message { role: Role::Assistant, content: Content::text("Done."), diff --git a/src/request/message.rs b/src/request/message.rs index a9cd478..180fdee 100644 --- a/src/request/message.rs +++ b/src/request/message.rs @@ -102,6 +102,15 @@ impl From for Message { } } +impl From<(Role, String)> for Message { + fn from((role, content): (Role, String)) -> Self { + Self { + role, + content: Content::SinglePart(content.into()), + } + } +} + impl From<(Role, Cow<'static, str>)> for Message { fn from((role, content): (Role, Cow<'static, str>)) -> Self { Self { @@ -120,6 +129,24 @@ impl From<(Role, &'static str)> for Message { } } +impl From for Message { + fn from(call: tool::Use) -> Self { + Message { + role: Role::Assistant, + content: call.into(), + } + } +} + +impl From for Message { + fn from(result: tool::Result) -> Self { + Message { + role: Role::User, + content: result.into(), + } + } +} + #[cfg(feature = "markdown")] impl crate::markdown::ToMarkdown for Message { /// Returns an iterator over the text as [`pulldown_cmark::Event`]s using @@ -135,7 +162,9 @@ impl crate::markdown::ToMarkdown for Message { let content = self.content.markdown_events_custom(options); let role = match self.content.last() { - Some(Block::ToolResult { is_error, .. }) => { + Some(Block::ToolResult { + result: tool::Result { is_error, .. }, + }) => { if *is_error { "Error" } else { @@ -165,7 +194,12 @@ impl std::fmt::Display for Message { /// Content of a [`Message`]. #[derive( - Debug, Serialize, Deserialize, derive_more::From, derive_more::IsVariant, + Clone, + Debug, + Serialize, + Deserialize, + derive_more::From, + derive_more::IsVariant, )] #[serde(rename_all = "snake_case")] #[serde(untagged)] @@ -360,8 +394,20 @@ impl From for Content { } } +impl From for Content { + fn from(call: tool::Use) -> Self { + Block::from(call).into() + } +} + +impl From for Content { + fn from(result: tool::Result) -> Self { + Block::from(result).into() + } +} + /// A [`Content`] [`Block`] of a [`Message`]. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(not(feature = "markdown"), derive(derive_more::Display))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] @@ -397,7 +443,7 @@ pub enum Block { ToolUse { /// Tool use input. #[serde(flatten)] - call: crate::tool::Use, + call: tool::Use, }, /// Result of a [`Tool`] call. This should only be used with the [`User`] /// role. @@ -406,16 +452,9 @@ pub enum Block { /// [`Tool`]: crate::Tool #[cfg_attr(not(feature = "markdown"), display(""))] ToolResult { - /// Unique Id for this tool call. - tool_use_id: String, - /// Output of the tool. - content: Content, - /// Whether the tool call result was an error. - is_error: bool, - /// Use prompt caching. See [`Block::cache`] for more information. - #[cfg(feature = "prompt-caching")] - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, + /// Tool result + #[serde(flatten)] + result: tool::Result, }, } @@ -505,7 +544,9 @@ impl Block { | Self::ToolUse { call: tool::Use { cache_control, .. }, } - | Self::ToolResult { cache_control, .. } => { + | Self::ToolResult { + result: tool::Result { cache_control, .. }, + } => { *cache_control = Some(CacheControl::Ephemeral); } } @@ -522,7 +563,9 @@ impl Block { | Self::ToolUse { call: tool::Use { cache_control, .. }, } - | Self::ToolResult { cache_control, .. } => cache_control.is_some(), + | Self::ToolResult { + result: tool::Result { cache_control, .. }, + } => cache_control.is_some(), } } @@ -634,6 +677,18 @@ impl From for Block { } } +impl From for Block { + fn from(call: tool::Use) -> Self { + Self::ToolUse { call } + } +} + +impl From for Block { + fn from(result: tool::Result) -> Self { + Self::ToolResult { result } + } +} + #[cfg(feature = "png")] impl From for Block { fn from(image: image::RgbaImage) -> Self { @@ -668,7 +723,7 @@ pub enum CacheControl { /// Image content for [`MultiPart`] [`Message`]s. /// /// [`MultiPart`]: Content::MultiPart -#[derive(Debug, Serialize, Deserialize, derive_more::Display)] +#[derive(Clone, Debug, Serialize, Deserialize, derive_more::Display)] #[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] @@ -1062,17 +1117,14 @@ mod tests { ); // Test tool result (success) - let message = Message { - role: Role::User, - content: Block::ToolResult { - tool_use_id: "tool_123".into(), - content: Content::SinglePart("Hello, world!".into()), - is_error: false, - #[cfg(feature = "prompt-caching")] - cache_control: None, - } - .into(), - }; + let message: Message = tool::Result { + tool_use_id: "tool_123".into(), + content: Content::SinglePart("Hello, world!".into()), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(); assert_eq!( message.markdown_custom(&opts).to_string(), @@ -1080,21 +1132,37 @@ mod tests { ); // Test tool result (error) - let message = Message { - role: Role::User, - content: Block::ToolResult { - tool_use_id: "tool_123".into(), - content: Content::SinglePart("Hello, world!".into()), - is_error: true, - #[cfg(feature = "prompt-caching")] - cache_control: None, - } - .into(), - }; + let message: Message = tool::Result { + tool_use_id: "tool_123".into(), + content: Content::SinglePart("Hello, world!".into()), + is_error: true, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(); assert_eq!( message.markdown_custom(&opts).to_string(), "### Error\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"tool_123\",\"content\":\"Hello, world!\",\"is_error\":true}\n````" ); } + + #[test] + fn test_block_tool_use() { + let expected = tool::Use { + id: "tool_123".into(), + name: "tool".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let block = Block::ToolUse { + call: expected.clone(), + }; + + assert_eq!(block.tool_use(), Some(&expected)); + } + + // TODO: Image tests } diff --git a/src/response.rs b/src/response.rs index f500f5d..3b7cc54 100644 --- a/src/response.rs +++ b/src/response.rs @@ -119,34 +119,116 @@ impl Response { mod tests { use super::*; - pub const RESPONSE_JSON: &str = r#"{ - "content": [ - { - "text": "Hi! My name is Claude.", - "type": "text" - } - ], - "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", - "model": "claude-3-5-sonnet-20240620", - "role": "assistant", - "stop_reason": "end_turn", - "stop_sequence": null, - "type": "message", - "usage": { - "input_tokens": 2095, - "output_tokens": 503 - } -}"#; + use std::borrow::Cow; + + const TEST_ID: &str = "test_id"; + + const CONTENT: &str = "Hello, world!"; + + const RESPONSE: Response = Response::Message { + message: Message { + id: Cow::Borrowed(TEST_ID), + message: request::Message { + role: request::message::Role::User, + content: request::message::Content::SinglePart(Cow::Borrowed( + CONTENT, + )), + }, + model: crate::Model::Sonnet35, + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 1, + #[cfg(feature = "prompt-caching")] + cache_creation_input_tokens: Some(2), + #[cfg(feature = "prompt-caching")] + cache_read_input_tokens: Some(3), + output_tokens: 4, + }, + }, + }; + + #[test] + fn test_into_stream() { + let mock_stream = crate::stream::tests::mock_stream(); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_stream().is_some()); + } + + #[test] + fn test_unwrap_stream() { + let mock_stream = crate::stream::tests::mock_stream(); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_stream().is_some()); + } + + #[test] + fn test_unwrap_message() { + assert_eq!( + RESPONSE.into_message().unwrap().content.to_string(), + "Hello, world!" + ); + } + + #[test] + fn test_message() { + assert_eq!( + RESPONSE.message().unwrap().content.to_string(), + "Hello, world!" + ); + } + + #[test] + fn test_into_message() { + assert_eq!( + RESPONSE.into_message().unwrap().content.to_string(), + "Hello, world!" + ); + } + + #[test] + fn test_into_response_message() { + assert_eq!( + RESPONSE + .into_response_message() + .unwrap() + .message + .content + .to_string(), + "Hello, world!" + ); + } + + #[test] + fn test_response_message() { + assert_eq!( + RESPONSE + .response_message() + .unwrap() + .message + .content + .to_string(), + "Hello, world!" + ); + } #[test] - fn deserialize_response_message() { - let message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); - assert_eq!(message.message.content.len(), 1); - assert_eq!(message.id, "msg_013Zva2CMHLNnXjNJJKqJ2EF"); - assert_eq!(message.model, crate::Model::Sonnet35); - assert!(matches!(message.stop_reason, Some(StopReason::EndTurn))); - assert_eq!(message.stop_sequence, None); - assert_eq!(message.usage.input_tokens, 2095); - assert_eq!(message.usage.output_tokens, 503); + fn test_unwrap_response_message() { + assert_eq!( + RESPONSE + .unwrap_response_message() + .message + .content + .to_string(), + "Hello, world!" + ); } } diff --git a/src/response/message.rs b/src/response/message.rs index b87907c..b99adbf 100644 --- a/src/response/message.rs +++ b/src/response/message.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{request, stream::MessageDelta, Model}; use serde::{Deserialize, Serialize}; @@ -7,7 +9,7 @@ use serde::{Deserialize, Serialize}; #[display("{}", message)] pub struct Message { /// Unique `id` for the message. - pub id: String, + pub id: Cow<'static, str>, /// Inner [`request::Message`]. #[serde(flatten)] pub message: request::Message, @@ -19,7 +21,7 @@ pub struct Message { /// triggered it. /// /// [`StopSequence`]: StopReason::StopSequence - pub stop_sequence: Option, + pub stop_sequence: Option>, /// Usage statistics for the message. pub usage: Usage, } @@ -82,3 +84,81 @@ pub struct Usage { /// Number of output tokens generated. pub output_tokens: u64, } + +#[cfg(test)] +mod tests { + use super::*; + + // FIXME: This is Copilot generated JSON. It should be replaced with actual + // response JSON, however this is pretty close to what the actual JSON looks + // like. + pub const RESPONSE_JSON: &str = r#"{ + "content": [ + { + "text": "Hi! My name is Claude.", + "type": "text" + } + ], + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "model": "claude-3-5-sonnet-20240620", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": null, + "type": "message", + "usage": { + "input_tokens": 2095, + "output_tokens": 503 + } +}"#; + + #[test] + fn deserialize_response_message() { + let message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); + assert_eq!(message.message.content.len(), 1); + assert_eq!(message.id, "msg_013Zva2CMHLNnXjNJJKqJ2EF"); + assert_eq!(message.model, crate::Model::Sonnet35); + assert!(matches!(message.stop_reason, Some(StopReason::EndTurn))); + assert_eq!(message.stop_sequence, None); + assert_eq!(message.usage.input_tokens, 2095); + assert_eq!(message.usage.output_tokens, 503); + } + + #[test] + fn test_apply_delta() { + let mut message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); + let delta = MessageDelta { + stop_reason: Some(StopReason::MaxTokens), + stop_sequence: Some("sequence".into()), + usage: Some(Usage { + input_tokens: 100, + output_tokens: 200, + ..Default::default() + }), + }; + + message.apply_delta(delta); + + assert_eq!(message.stop_reason, Some(StopReason::MaxTokens)); + assert_eq!(message.stop_sequence, Some("sequence".into())); + assert_eq!(message.usage.input_tokens, 100); + assert_eq!(message.usage.output_tokens, 200); + } + + #[test] + fn test_tool_use() { + let mut message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); + assert!(message.tool_use().is_none()); + + message.stop_reason = Some(StopReason::ToolUse); + assert!(message.tool_use().is_none()); + + message.message.content.push(crate::tool::Use { + id: "id".into(), + name: "name".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }); + assert!(message.tool_use().is_some()); + } +} diff --git a/src/stream.rs b/src/stream.rs index ef37073..68c2e2a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -2,7 +2,7 @@ //! associated types and errors only used when streaming. use futures::StreamExt; use serde::{Deserialize, Serialize}; -use std::pin::Pin; +use std::{borrow::Cow, pin::Pin}; #[allow(unused_imports)] // `Content`, `request` Used in docs. use crate::{ @@ -172,7 +172,7 @@ pub struct MessageDelta { pub stop_reason: Option, /// Stop sequence. #[serde(skip_serializing_if = "Option::is_none")] - pub stop_sequence: Option, + pub stop_sequence: Option>, /// Token usage. #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, @@ -302,7 +302,7 @@ impl futures::Stream for Stream { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; @@ -311,6 +311,28 @@ mod tests { pub const CONTENT_BLOCK_START: &str = "{\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"} }"; pub const CONTENT_BLOCK_DELTA: &str = "{\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Certainly! I\"} }"; + pub fn mock_stream() -> Stream { + let inner = futures::stream::iter( + [ + Ok(eventsource_stream::Event { + data: CONTENT_BLOCK_START.into(), + id: "123".into(), + event: "content_block_start".into(), + retry: None, + }), + Ok(eventsource_stream::Event { + data: CONTENT_BLOCK_DELTA.into(), + id: "123".into(), + event: "content_block_delta".into(), + retry: None, + }), + ] + .into_iter(), + ); + + Stream::new(inner) + } + #[test] fn test_content_block_start() { let event: Event = serde_json::from_str(CONTENT_BLOCK_START).unwrap(); diff --git a/src/tool.rs b/src/tool.rs index 9a5dac1..6efb7cf 100644 --- a/src/tool.rs +++ b/src/tool.rs @@ -1,6 +1,8 @@ //! [`Tool`] and tool [`Choice`] types for the Anthropic Messages API. use serde::{Deserialize, Serialize}; +use crate::request::message::Content; + /// Choice of [`Tool`] for a specific [`request::Message`]. /// /// [`request::Message`]: crate::request::Message @@ -70,7 +72,9 @@ impl Tool { impl TryFrom for Tool { type Error = serde_json::Error; - fn try_from(value: serde_json::Value) -> Result { + fn try_from( + value: serde_json::Value, + ) -> std::result::Result { serde_json::from_value(value) } } @@ -100,7 +104,9 @@ pub struct Use { impl TryFrom for Use { type Error = serde_json::Error; - fn try_from(value: serde_json::Value) -> Result { + fn try_from( + value: serde_json::Value, + ) -> std::result::Result { serde_json::from_value(value) } } @@ -138,3 +144,27 @@ impl std::fmt::Display for Use { self.write_markdown(f) } } + +/// Result of [`Tool`] [`Use`] sent back to the [`Assistant`] as a [`User`] +/// [`Message`]. +/// +/// [`Assistant`]: crate::request::message::Role::Assistant +/// [`User`]: crate::request::message::Role::User +/// [`Message`]: crate::request::Message +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +// On the one hand this can clash with the `Result` type from the standard +// library, but on the other hand it's what the API uses, and I'm trying to +// be as faithful to the API as possible. +pub struct Result { + /// Unique Id for this tool call. + pub tool_use_id: String, + /// Output of the tool. + pub content: Content, + /// Whether the tool call result was an error. + pub is_error: bool, + /// Use prompt caching. See [`Block::cache`] for more information. + #[cfg(feature = "prompt-caching")] + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} From dd0557a55ce5fe83304f29442b6f6295b0270378 Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Thu, 12 Sep 2024 22:33:12 -0700 Subject: [PATCH 5/7] More test coverage - Test the Stream with an actual stream dump from the API docs. The flexibility of `Stream::new` has paid off. - Add itertools as a dev dependency for parsing the test SSE stream elegantly. --- Cargo.toml | 1 + src/response.rs | 8 ++- src/stream.rs | 120 +++++++++++++++++++++++++++++++++------ src/tool.rs | 56 ++++++++++++++++++ test/data/sse.stream.txt | 95 +++++++++++++++++++++++++++++++ 5 files changed, 262 insertions(+), 18 deletions(-) create mode 100644 test/data/sse.stream.txt diff --git a/Cargo.toml b/Cargo.toml index 08fcc51..ee45c97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ static_assertions = "1" clap = { version = "4", features = ["derive"] } env_logger = "0.11" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +itertools = "0.13" [features] # rustls because I am sick of getting Dependabot alerts for OpenSSL. diff --git a/src/response.rs b/src/response.rs index 3b7cc54..806f64a 100644 --- a/src/response.rs +++ b/src/response.rs @@ -150,7 +150,9 @@ mod tests { #[test] fn test_into_stream() { - let mock_stream = crate::stream::tests::mock_stream(); + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); let response = Response::Stream { stream: mock_stream, @@ -161,7 +163,9 @@ mod tests { #[test] fn test_unwrap_stream() { - let mock_stream = crate::stream::tests::mock_stream(); + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); let response = Response::Stream { stream: mock_stream, diff --git a/src/stream.rs b/src/stream.rs index 68c2e2a..e211105 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -288,6 +288,16 @@ impl Stream { } }) } + + // TODO: Figure out an ergonomic way to handle tool use when streaming. We + // may need another wrapper stream to store json deltas until a full block + // is received. This would allow us to merge json deltas and then emit a + // tool use event. Emitting `Block`s might not be a bad idea, but it would + // delay the text output, which is the primary use case for streaming. Even + // though events can be made up of multiple text blocks, generally the model + // only generates a single block per message per type. Waiting for an entire + // text block would mean waiting for the entire message. Waiting on JSON, is + // however necessary since we can't do anything useful with partial JSON. } impl futures::Stream for Stream { @@ -304,6 +314,8 @@ impl futures::Stream for Stream { #[cfg(test)] pub(crate) mod tests { + use futures::TryStreamExt; + use super::*; // Actual JSON from the API. @@ -311,23 +323,26 @@ pub(crate) mod tests { pub const CONTENT_BLOCK_START: &str = "{\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"} }"; pub const CONTENT_BLOCK_DELTA: &str = "{\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Certainly! I\"} }"; - pub fn mock_stream() -> Stream { + /// Creates a mock stream from a string (likely `include_str!`). The string + /// should be a series of `event`, `data`, and empty lines (a SSE stream). + /// Anthropic provides such example data in the API documentation. + pub fn mock_stream(text: &'static str) -> Stream { + use itertools::Itertools; + + // TODO: one of every possible variants, even if it doesn't make sense. let inner = futures::stream::iter( - [ - Ok(eventsource_stream::Event { - data: CONTENT_BLOCK_START.into(), - id: "123".into(), - event: "content_block_start".into(), - retry: None, - }), + // first line should be `event`, second line should be `data`, third + // line should be empty. + text.lines().tuples().map(|(event, data, _empty)| { + assert!(_empty.is_empty()); + Ok(eventsource_stream::Event { - data: CONTENT_BLOCK_DELTA.into(), - id: "123".into(), - event: "content_block_delta".into(), + event: event.strip_prefix("event: ").unwrap().into(), + data: data.strip_prefix("data: ").unwrap().into(), + id: "".into(), retry: None, - }), - ] - .into_iter(), + }) + }), ); Stream::new(inner) @@ -383,7 +398,8 @@ pub(crate) mod tests { #[test] fn test_content_block_delta_merge() { - let delta = Delta::Text { + // Merge text deltas. + let text_delta = Delta::Text { text: "Certainly! I".into(), } .merge(Delta::Text { @@ -394,10 +410,82 @@ pub(crate) mod tests { .unwrap(); assert_eq!( - delta, + text_delta, Delta::Text { text: "Certainly! I can do".into() } ); + + // Merge JSON deltas. + let json_delta = Delta::Json { + partial_json: r#"{"key":"#.into(), + } + .merge(Delta::Json { + partial_json: r#""value"}"#.into(), + }) + .unwrap(); + + assert_eq!( + json_delta, + Delta::Json { + partial_json: r#"{"key":"value"}"#.into() + } + ); + + // Content mismatch. + let mismatch = json_delta.merge(text_delta).unwrap_err(); + + assert_eq!( + mismatch.to_string(), + ContentMismatch { + from: Delta::Text { + text: "Certainly! I can do".into() + }, + to: "Delta::Json" + } + .to_string() + ); + + // Other way around, for coverage. + let text_delta = Delta::Text { + text: "Certainly!".into(), + }; + let json_delta = Delta::Json { + partial_json: r#"{"key":"value"}"#.into(), + }; + + let mismatch = text_delta.merge(json_delta).unwrap_err(); + + assert_eq!( + mismatch.to_string(), + ContentMismatch { + from: Delta::Json { + partial_json: r#"{"key":"value"}"#.into() + }, + to: "Delta::Text" + } + .to_string() + ); + } + + #[tokio::test] + async fn test_stream() { + // sse.stream.txt is from the API docs and includes one of every event + // type, with the exception of fatal errors, but they all have the same + // structure, so if one works, they all should. It covers every code + // path in the `Stream` struct and every event type. + let stream = mock_stream(include_str!("../test/data/sse.stream.txt")); + + let text: String = stream + .filter_rate_limit() + .text() + .try_collect() + .await + .unwrap(); + + assert_eq!( + text, + "Okay, let's check the weather for San Francisco, CA:" + ); } } diff --git a/src/tool.rs b/src/tool.rs index 6efb7cf..402da3e 100644 --- a/src/tool.rs +++ b/src/tool.rs @@ -168,3 +168,59 @@ pub struct Result { #[serde(skip_serializing_if = "Option::is_none")] pub cache_control: Option, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn use_try_from_value() { + let value = serde_json::json!({ + "id": "test_id", + "name": "test_name", + "input": { + "test_key": "test_value" + } + }); + + let use_ = Use::try_from(value).unwrap(); + + assert_eq!(use_.id, "test_id"); + assert_eq!(use_.name, "test_name"); + assert_eq!( + use_.input, + serde_json::json!({ + "test_key": "test_value" + }) + ); + } + + #[test] + #[cfg(feature = "markdown")] + fn test_use_markdown() { + use crate::markdown::ToMarkdown; + + let use_ = Use { + id: "test_id".into(), + name: "test_name".into(), + input: serde_json::json!({ + "test_key": "test_value" + }), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let markdown = use_.markdown_verbose(); + + assert_eq!( + markdown.as_ref(), + "\n````json\n{\"id\":\"test_id\",\"name\":\"test_name\",\"input\":{\"test_key\":\"test_value\"}}\n````" + ); + + // By default the tool use is not included in the markdown, however this + // might change in the future. Really, our Display impl could just + // return an empty &str but this is more consistent with the rest of the + // crate. + assert_eq!(use_.to_string(), ""); + } +} diff --git a/test/data/sse.stream.txt b/test/data/sse.stream.txt new file mode 100644 index 0000000..b16c6db --- /dev/null +++ b/test/data/sse.stream.txt @@ -0,0 +1,95 @@ +event: message_start +data: {"type":"message_start","message":{"id":"msg_014p7gG3wDgGV9EUtLvnow3U","type":"message","role":"assistant","model":"claude-3-haiku-20240307","stop_sequence":null,"usage":{"input_tokens":472,"output_tokens":2},"content":[],"stop_reason":null}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: ping +data: {"type": "ping"} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Okay"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" let"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"'s"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" check"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" weather"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" San"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Francisco"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" CA"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":":"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: content_block_start +data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" Francisc"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"o,"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" CA\""}} + +event: error +data: {"type": "error", "error": {"type": "rate_limit_error", "message": "Rate limit exceeded"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":", "}} + +event: error +data: {"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\"unit\": \"fah"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"renheit\"}"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":1} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":89}} + +event: message_stop +data: {"type":"message_stop"} From bc6f6e2eafaabdb5ccd8730010658f47d9adb3ab Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Thu, 12 Sep 2024 22:40:33 -0700 Subject: [PATCH 6/7] Fix readme codecov badge link it won't be valid until we push to main, but this should be correct. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 37de5bc..9b3e3b7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # `misanthropic` ![Build Status](https://github.com/mdegans/misanthropic/actions/workflows/tests.yaml/badge.svg) -[![codecov](https://codecov.io/gh/mdegans/misanthropic/branch/main/graph/badge.svg)](https://codecov.io/gh/your-username/your-repo) +[![codecov](https://codecov.io/gh/mdegans/misanthropic/branch/main/graph/badge.svg)](https://codecov.io/gh/mdegans/misanthropic) Is an unofficial simple, ergonomic, client for the Anthropic Messages API. From 4ae9de83e958d96bd430ac198cbae157da5d80b9 Mon Sep 17 00:00:00 2001 From: Michael de Gans Date: Fri, 13 Sep 2024 16:04:49 -0700 Subject: [PATCH 7/7] More coverage - Covered most of Client. Integration tests using Anthropic's service. - Covered more of `message.rs`. Found an image Display bug. - Switch coverage generation to `llvm_cov`. Tarpaulin is shit and can't run all tests at once. --- .github/workflows/tests.yaml | 25 ++++++---- .gitignore | 4 +- README.md | 2 +- src/client.rs | 83 ++++++++++++++++++++++++++++++++ src/key.rs | 34 +++++++++++-- src/markdown.rs | 32 ++++++++++++- src/request.rs | 4 ++ src/request/message.rs | 93 +++++++++++++++++++++++++++++++----- src/response.rs | 79 +++++++++++++++++++++++++++++- 9 files changed, 325 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 147da98..a6d9a89 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -13,18 +13,15 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - rust: [stable, beta, nightly] steps: - name: Checkout code uses: actions/checkout@v2 - name: Set up Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ matrix.rust }} - profile: minimal - override: true + components: llvm-tools-preview - name: Cache cargo registry uses: actions/cache@v2 @@ -57,19 +54,27 @@ jobs: run: cargo test --all-features --verbose # This should only happen on push to main. PRs should not upload coverage. - - name: Install tarpaulin + - name: Install llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' - run: cargo install cargo-tarpaulin - - name: Run tarpaulin + - name: Install nextest + uses: taiki-e/install-action@nextest if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' - run: cargo tarpaulin --out Xml --all-features + + - name: Write API key to api.key + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' + run: echo ${{ secrets.ANTHROPIC_API_KEY }} > api.key + + - name: Collect coverage data (including ignored tests) + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' + run: cargo llvm-cov nextest --all-features --run-ignored all --lcov --output-path lcov.info - name: Upload coverage to Codecov if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' uses: codecov/codecov-action@v2 with: - files: ./cobertura.xml + files: lcov.info flags: unittests name: codecov-umbrella fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 0d677be..9fa9f16 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /target Cargo.lock .vscode -cobertura.xml \ No newline at end of file +cobertura.xml +api.key +lcov.info diff --git a/README.md b/README.md index 9b3e3b7..31a249a 100644 --- a/README.md +++ b/README.md @@ -66,11 +66,11 @@ println!("{}", message); - [x] Tool use, - [x] Streaming responses - [x] Message responses -- [x] Zero-copy where possible - [x] Image support with or without the `image` crate - [x] Markdown formatting of messages, including images - [x] Prompt caching support - [x] Custom request and endpoint support +- [ ] Zero-copy serde - Coming soon! - [ ] Amazon Bedrock support - [ ] Vertex AI support diff --git a/src/client.rs b/src/client.rs index b63784d..6babe42 100644 --- a/src/client.rs +++ b/src/client.rs @@ -350,8 +350,12 @@ pub(crate) struct AnthropicErrorWrapper { #[cfg(test)] mod tests { + use futures::TryStreamExt; + use super::*; + // Test error deserialization. + #[test] fn test_anthropic_error_deserialize() { const INVALID_REQUEST: &str = @@ -455,4 +459,83 @@ mod tests { } ); } + + // Test the Client + + use crate::{request::message::Role, Request}; + + const CRATE_ROOT: &str = env!("CARGO_MANIFEST_DIR"); + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const FAKE_API_KEY: &str = "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"; + + // Error message for when the API key is not found. + const NO_API_KEY: &str = "API key not found. Create a file named `api.key` in the crate root with your API key."; + + // Load the API key from the `api.key` file in the crate root. + fn load_api_key() -> Option { + use std::fs::File; + use std::io::Read; + use std::path::Path; + + let mut file = + File::open(Path::new(CRATE_ROOT).join("api.key")).ok()?; + let mut key = String::new(); + file.read_to_string(&mut key).unwrap(); + Some(key.trim().to_string()) + } + + #[test] + fn test_client_new() { + let client = Client::new(FAKE_API_KEY.to_string()).unwrap(); + assert_eq!(client.key.to_string(), FAKE_API_KEY); + + // Apparently there isn't a way to check if the headers have been set + // on the client. Making a request returns a builder but the headers + // are not exposed. + } + + #[tokio::test] + #[ignore = "This test requires a real API key."] + async fn test_client_message() { + let key = load_api_key().expect(NO_API_KEY); + let client = Client::new(key).unwrap(); + + let message = client + .message(Request::default().messages([( + Role::User, + "Emit just the \"🙏\" emoji, please.", + )])) + .await + .unwrap(); + + assert_eq!(message.message.role, Role::Assistant); + assert!(message.to_string().contains("🙏")); + } + + #[tokio::test] + #[ignore = "This test requires a real API key."] + async fn test_client_stream() { + let key = load_api_key().expect(NO_API_KEY); + let client = Client::new(key).unwrap(); + + let stream = client + .stream(Request::default().messages([( + Role::User, + "Emit just the \"🙏\" emoji, please.", + )])) + .await + .unwrap(); + + let msg: String = stream + .filter_rate_limit() + .text() + .try_collect() + .await + .unwrap(); + + assert!(msg.contains("🙏")); + } } diff --git a/src/key.rs b/src/key.rs index 0510513..5d19d95 100644 --- a/src/key.rs +++ b/src/key.rs @@ -15,8 +15,11 @@ pub type Arr = [u8; LEN]; /// /// [`key::LEN`]: LEN #[derive(Debug, thiserror::Error)] -#[error("Invalid key length: {0} (expected {LEN})")] -pub struct InvalidKeyLength(usize); +#[error("Invalid key length: {actual} (expected {LEN})")] +pub struct InvalidKeyLength { + /// The incorrect actual length of the key. + pub actual: usize, +} /// Stores an Anthropic API key securely. The API key is encrypted in memory. /// The object features a [`Display`] implementation that can be used to write @@ -60,8 +63,9 @@ impl TryFrom> for Key { fn try_from(mut v: Vec) -> Result { let mut arr: Arr = [0; LEN]; if v.len() != LEN { + let actual = v.len(); v.zeroize(); - return Err(InvalidKeyLength(v.len())); + return Err(InvalidKeyLength { actual }); } arr.copy_from_slice(&v); @@ -105,3 +109,27 @@ impl std::fmt::Display for Key { write!(f, "{}", key_str) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const API_KEY: &str = "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"; + + #[test] + fn test_key() { + let key = Key::try_from(API_KEY.to_string()).unwrap(); + let key_str = key.to_string(); + assert_eq!(key_str, API_KEY); + } + + #[test] + fn test_invalid_key_length() { + let key = "test_key".to_string(); + let err = Key::try_from(key).unwrap_err(); + assert_eq!(err.to_string(), "Invalid key length: 8 (expected 108)"); + } +} diff --git a/src/markdown.rs b/src/markdown.rs index 21de394..f5e6883 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -170,7 +170,7 @@ impl PartialEq for Markdown { pub trait ToMarkdown { /// Render the type to a [`Markdown`] string with [`DEFAULT_OPTIONS`]. fn markdown(&self) -> Markdown { - self.markdown_custom(DEFAULT_OPTIONS_REF) + self.markdown_events().into() } /// Render the type to a [`Markdown`] string with custom [`Options`]. @@ -232,6 +232,8 @@ impl Default for Options { #[cfg(test)] mod tests { + use crate::request::{message::Role, Message}; + use super::*; use std::borrow::Borrow; @@ -246,6 +248,21 @@ mod tests { assert!(options == options2); } + #[test] + fn test_options_from_pulldown() { + let inner = pulldown_cmark::Options::empty(); + let options: Options = inner.into(); + assert_eq!(options.inner, inner); + } + + #[test] + fn test_options_verbose() { + let options = Options::verbose(); + assert!(options.tool_use); + assert!(options.tool_results); + assert!(options.system); + } + #[test] fn test_markdown() { let expected = "Hello, **world**!"; @@ -257,4 +274,17 @@ mod tests { let markdown: String = markdown.into(); assert_eq!(markdown, expected); } + + #[test] + fn test_message_markdown() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + assert_eq!( + message.markdown().as_ref(), + "### User\n\nHello, **world**!" + ); + } } diff --git a/src/request.rs b/src/request.rs index 952ebbc..b6db3f3 100644 --- a/src/request.rs +++ b/src/request.rs @@ -721,6 +721,10 @@ mod tests { #[test] #[cfg(feature = "prompt-caching")] fn test_cache() { + // Test with nothing to cache. This should be a no-op. + let request = Request::default().cache(); + assert!(request == Request::default()); + // Test with no system prompt or messages that the call to cache affects // the tools. let request = Request::default().add_tool(Tool { diff --git a/src/request/message.rs b/src/request/message.rs index 180fdee..2934f7c 100644 --- a/src/request/message.rs +++ b/src/request/message.rs @@ -106,7 +106,7 @@ impl From<(Role, String)> for Message { fn from((role, content): (Role, String)) -> Self { Self { role, - content: Content::SinglePart(content.into()), + content: content.into(), } } } @@ -115,7 +115,7 @@ impl From<(Role, Cow<'static, str>)> for Message { fn from((role, content): (Role, Cow<'static, str>)) -> Self { Self { role, - content: Content::SinglePart(content), + content: content.into(), } } } @@ -124,7 +124,7 @@ impl From<(Role, &'static str)> for Message { fn from((role, content): (Role, &'static str)) -> Self { Self { role, - content: Content::SinglePart(Cow::Borrowed(content)), + content: content.into(), } } } @@ -254,12 +254,11 @@ impl Content { } /// Add a [`Block`] to the [`Content`]. If the [`Content`] is a - /// [`SinglePart`], it will be converted to a [`MultiPart`]. Returns the - /// index of the added [`Block`]. + /// [`SinglePart`], it will be converted to a [`MultiPart`]. /// /// [`SinglePart`]: Content::SinglePart /// [`MultiPart`]: Content::MultiPart - pub fn push

(&mut self, part: P) -> usize + pub fn push

(&mut self, part: P) where P: Into, { @@ -275,10 +274,6 @@ impl Content { if let Content::MultiPart(parts) = self { parts.push(part.into()); - - parts.len() - 1 - } else { - unreachable!("Content is not MultiPart"); } } @@ -693,9 +688,11 @@ impl From for Block { impl From for Block { fn from(image: image::RgbaImage) -> Self { Image::encode(MediaType::Png, image) - // Unwrap can never panic unless the PNG encoding fails. + // Unwrap can never panic unless the PNG encoding fails, which + // should really never happen, but no matter what we don't panic. .unwrap_or_else(|e| { - eprintln!("Error encoding image: {}", e); + #[cfg(feature = "log")] + log::error!("Error encoding image: {}", e); Image::from_parts(MediaType::Png, String::new()) }) .into() @@ -838,7 +835,11 @@ pub enum MediaType { impl std::fmt::Display for MediaType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Use serde to get the string representation. - write!(f, "{}", serde_json::to_string(self).unwrap()) + write!( + f, + "{}", + serde_json::to_string(self).unwrap().trim_matches('"') + ) } } @@ -899,6 +900,12 @@ mod tests { {"type": "text", "text": "How are you?"} ]"#; + #[test] + fn test_role_display() { + assert_eq!(Role::User.to_string(), "User"); + assert_eq!(Role::Assistant.to_string(), "Assistant"); + } + #[test] fn deserialize_content() { let content: Content = serde_json::from_str(CONTENT_SINGLE).unwrap(); @@ -919,6 +926,12 @@ mod tests { assert_eq!(message.to_string(), "### User\n\nHello, world"); } + #[test] + fn test_message_from_role_string_tuple() { + let message: Message = (Role::User, "Hello, world!".to_string()).into(); + assert_eq!(message.to_string(), "### User\n\nHello, world!"); + } + #[test] #[cfg(feature = "markdown")] fn test_merge_deltas() { @@ -1164,5 +1177,59 @@ mod tests { assert_eq!(block.tool_use(), Some(&expected)); } + #[test] + fn test_block_from_str() { + let block: Block = "Hello, world!".into(); + assert_eq!(block.to_string(), "Hello, world!"); + } + + #[test] + fn test_block_from_string() { + let block: Block = "Hello, world!".to_string().into(); + assert_eq!(block.to_string(), "Hello, world!"); + } + + #[test] + fn test_block_from_image() { + let image = Image::from_parts(MediaType::Png, "data".to_string()); + let block: Block = image.into(); + assert_eq!(block.to_string(), "![Image]()"); + } + // TODO: Image tests + #[test] + #[cfg(feature = "png")] + fn test_block_from_rgba_image() { + let image = image::RgbaImage::new(1, 1); + let block: Block = image.into(); + assert!(matches!(block, Block::Image { .. })); + } + + #[test] + #[cfg(feature = "png")] + fn test_block_from_dynamic_image() { + let image = image::DynamicImage::new_rgba8(1, 1); + let block: Block = image.into(); + assert!(matches!(block, Block::Image { .. })); + } + + #[test] + #[cfg(feature = "png")] + fn test_image_from_compressed() { + use std::io::Cursor; + + // Encode a sample image + let expected = image::RgbaImage::new(1, 1); + let mut encoded = Cursor::new(vec![]); + expected + .write_to(&mut encoded, image::ImageFormat::Png) + .unwrap(); + + // Decode the image + let image = + Image::from_compressed(MediaType::Png, encoded.into_inner()); + let actual: image::RgbaImage = image.try_into().unwrap(); + + assert_eq!(actual, expected); + } } diff --git a/src/response.rs b/src/response.rs index 806f64a..b5a29d2 100644 --- a/src/response.rs +++ b/src/response.rs @@ -159,6 +159,7 @@ mod tests { }; assert!(response.into_stream().is_some()); + assert!(RESPONSE.into_stream().is_none()); } #[test] @@ -171,23 +172,53 @@ mod tests { stream: mock_stream, }; - assert!(response.into_stream().is_some()); + let _stream = response.unwrap_stream(); + } + + #[test] + #[should_panic] + fn test_unwrap_stream_panics() { + let _panic = RESPONSE.unwrap_stream(); } #[test] fn test_unwrap_message() { assert_eq!( - RESPONSE.into_message().unwrap().content.to_string(), + RESPONSE.unwrap_message().content.to_string(), "Hello, world!" ); } + #[test] + #[should_panic] + fn test_unwrap_message_panics() { + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + let _panic = response.unwrap_message(); + } + #[test] fn test_message() { assert_eq!( RESPONSE.message().unwrap().content.to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.message().is_none()); } #[test] @@ -196,6 +227,16 @@ mod tests { RESPONSE.into_message().unwrap().content.to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_message().is_none()); } #[test] @@ -209,6 +250,16 @@ mod tests { .to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_response_message().is_none()); } #[test] @@ -222,6 +273,16 @@ mod tests { .to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.response_message().is_none()); } #[test] @@ -235,4 +296,18 @@ mod tests { "Hello, world!" ); } + + #[test] + #[should_panic] + fn test_unwrap_response_message_panics() { + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + let _panic = response.unwrap_response_message(); + } }