diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 147da98..a6d9a89 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -13,18 +13,15 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - rust: [stable, beta, nightly] steps: - name: Checkout code uses: actions/checkout@v2 - name: Set up Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ matrix.rust }} - profile: minimal - override: true + components: llvm-tools-preview - name: Cache cargo registry uses: actions/cache@v2 @@ -57,19 +54,27 @@ jobs: run: cargo test --all-features --verbose # This should only happen on push to main. PRs should not upload coverage. - - name: Install tarpaulin + - name: Install llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' - run: cargo install cargo-tarpaulin - - name: Run tarpaulin + - name: Install nextest + uses: taiki-e/install-action@nextest if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' - run: cargo tarpaulin --out Xml --all-features + + - name: Write API key to api.key + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' + run: echo ${{ secrets.ANTHROPIC_API_KEY }} > api.key + + - name: Collect coverage data (including ignored tests) + if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' + run: cargo llvm-cov nextest --all-features --run-ignored all --lcov --output-path lcov.info - name: Upload coverage to Codecov if: matrix.os == 'ubuntu-latest' && github.event_name == 'push' uses: codecov/codecov-action@v2 with: - files: ./cobertura.xml + files: lcov.info flags: unittests name: codecov-umbrella fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 0d677be..9fa9f16 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /target Cargo.lock .vscode -cobertura.xml \ No newline at end of file +cobertura.xml +api.key +lcov.info diff --git a/README.md b/README.md index 9b3e3b7..31a249a 100644 --- a/README.md +++ b/README.md @@ -66,11 +66,11 @@ println!("{}", message); - [x] Tool use, - [x] Streaming responses - [x] Message responses -- [x] Zero-copy where possible - [x] Image support with or without the `image` crate - [x] Markdown formatting of messages, including images - [x] Prompt caching support - [x] Custom request and endpoint support +- [ ] Zero-copy serde - Coming soon! - [ ] Amazon Bedrock support - [ ] Vertex AI support diff --git a/src/client.rs b/src/client.rs index b63784d..6babe42 100644 --- a/src/client.rs +++ b/src/client.rs @@ -350,8 +350,12 @@ pub(crate) struct AnthropicErrorWrapper { #[cfg(test)] mod tests { + use futures::TryStreamExt; + use super::*; + // Test error deserialization. + #[test] fn test_anthropic_error_deserialize() { const INVALID_REQUEST: &str = @@ -455,4 +459,83 @@ mod tests { } ); } + + // Test the Client + + use crate::{request::message::Role, Request}; + + const CRATE_ROOT: &str = env!("CARGO_MANIFEST_DIR"); + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const FAKE_API_KEY: &str = "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"; + + // Error message for when the API key is not found. + const NO_API_KEY: &str = "API key not found. Create a file named `api.key` in the crate root with your API key."; + + // Load the API key from the `api.key` file in the crate root. + fn load_api_key() -> Option { + use std::fs::File; + use std::io::Read; + use std::path::Path; + + let mut file = + File::open(Path::new(CRATE_ROOT).join("api.key")).ok()?; + let mut key = String::new(); + file.read_to_string(&mut key).unwrap(); + Some(key.trim().to_string()) + } + + #[test] + fn test_client_new() { + let client = Client::new(FAKE_API_KEY.to_string()).unwrap(); + assert_eq!(client.key.to_string(), FAKE_API_KEY); + + // Apparently there isn't a way to check if the headers have been set + // on the client. Making a request returns a builder but the headers + // are not exposed. + } + + #[tokio::test] + #[ignore = "This test requires a real API key."] + async fn test_client_message() { + let key = load_api_key().expect(NO_API_KEY); + let client = Client::new(key).unwrap(); + + let message = client + .message(Request::default().messages([( + Role::User, + "Emit just the \"🙏\" emoji, please.", + )])) + .await + .unwrap(); + + assert_eq!(message.message.role, Role::Assistant); + assert!(message.to_string().contains("🙏")); + } + + #[tokio::test] + #[ignore = "This test requires a real API key."] + async fn test_client_stream() { + let key = load_api_key().expect(NO_API_KEY); + let client = Client::new(key).unwrap(); + + let stream = client + .stream(Request::default().messages([( + Role::User, + "Emit just the \"🙏\" emoji, please.", + )])) + .await + .unwrap(); + + let msg: String = stream + .filter_rate_limit() + .text() + .try_collect() + .await + .unwrap(); + + assert!(msg.contains("🙏")); + } } diff --git a/src/key.rs b/src/key.rs index 0510513..5d19d95 100644 --- a/src/key.rs +++ b/src/key.rs @@ -15,8 +15,11 @@ pub type Arr = [u8; LEN]; /// /// [`key::LEN`]: LEN #[derive(Debug, thiserror::Error)] -#[error("Invalid key length: {0} (expected {LEN})")] -pub struct InvalidKeyLength(usize); +#[error("Invalid key length: {actual} (expected {LEN})")] +pub struct InvalidKeyLength { + /// The incorrect actual length of the key. + pub actual: usize, +} /// Stores an Anthropic API key securely. The API key is encrypted in memory. /// The object features a [`Display`] implementation that can be used to write @@ -60,8 +63,9 @@ impl TryFrom> for Key { fn try_from(mut v: Vec) -> Result { let mut arr: Arr = [0; LEN]; if v.len() != LEN { + let actual = v.len(); v.zeroize(); - return Err(InvalidKeyLength(v.len())); + return Err(InvalidKeyLength { actual }); } arr.copy_from_slice(&v); @@ -105,3 +109,27 @@ impl std::fmt::Display for Key { write!(f, "{}", key_str) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const API_KEY: &str = "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"; + + #[test] + fn test_key() { + let key = Key::try_from(API_KEY.to_string()).unwrap(); + let key_str = key.to_string(); + assert_eq!(key_str, API_KEY); + } + + #[test] + fn test_invalid_key_length() { + let key = "test_key".to_string(); + let err = Key::try_from(key).unwrap_err(); + assert_eq!(err.to_string(), "Invalid key length: 8 (expected 108)"); + } +} diff --git a/src/markdown.rs b/src/markdown.rs index 21de394..f5e6883 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -170,7 +170,7 @@ impl PartialEq for Markdown { pub trait ToMarkdown { /// Render the type to a [`Markdown`] string with [`DEFAULT_OPTIONS`]. fn markdown(&self) -> Markdown { - self.markdown_custom(DEFAULT_OPTIONS_REF) + self.markdown_events().into() } /// Render the type to a [`Markdown`] string with custom [`Options`]. @@ -232,6 +232,8 @@ impl Default for Options { #[cfg(test)] mod tests { + use crate::request::{message::Role, Message}; + use super::*; use std::borrow::Borrow; @@ -246,6 +248,21 @@ mod tests { assert!(options == options2); } + #[test] + fn test_options_from_pulldown() { + let inner = pulldown_cmark::Options::empty(); + let options: Options = inner.into(); + assert_eq!(options.inner, inner); + } + + #[test] + fn test_options_verbose() { + let options = Options::verbose(); + assert!(options.tool_use); + assert!(options.tool_results); + assert!(options.system); + } + #[test] fn test_markdown() { let expected = "Hello, **world**!"; @@ -257,4 +274,17 @@ mod tests { let markdown: String = markdown.into(); assert_eq!(markdown, expected); } + + #[test] + fn test_message_markdown() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + assert_eq!( + message.markdown().as_ref(), + "### User\n\nHello, **world**!" + ); + } } diff --git a/src/request.rs b/src/request.rs index 952ebbc..b6db3f3 100644 --- a/src/request.rs +++ b/src/request.rs @@ -721,6 +721,10 @@ mod tests { #[test] #[cfg(feature = "prompt-caching")] fn test_cache() { + // Test with nothing to cache. This should be a no-op. + let request = Request::default().cache(); + assert!(request == Request::default()); + // Test with no system prompt or messages that the call to cache affects // the tools. let request = Request::default().add_tool(Tool { diff --git a/src/request/message.rs b/src/request/message.rs index 180fdee..2934f7c 100644 --- a/src/request/message.rs +++ b/src/request/message.rs @@ -106,7 +106,7 @@ impl From<(Role, String)> for Message { fn from((role, content): (Role, String)) -> Self { Self { role, - content: Content::SinglePart(content.into()), + content: content.into(), } } } @@ -115,7 +115,7 @@ impl From<(Role, Cow<'static, str>)> for Message { fn from((role, content): (Role, Cow<'static, str>)) -> Self { Self { role, - content: Content::SinglePart(content), + content: content.into(), } } } @@ -124,7 +124,7 @@ impl From<(Role, &'static str)> for Message { fn from((role, content): (Role, &'static str)) -> Self { Self { role, - content: Content::SinglePart(Cow::Borrowed(content)), + content: content.into(), } } } @@ -254,12 +254,11 @@ impl Content { } /// Add a [`Block`] to the [`Content`]. If the [`Content`] is a - /// [`SinglePart`], it will be converted to a [`MultiPart`]. Returns the - /// index of the added [`Block`]. + /// [`SinglePart`], it will be converted to a [`MultiPart`]. /// /// [`SinglePart`]: Content::SinglePart /// [`MultiPart`]: Content::MultiPart - pub fn push

(&mut self, part: P) -> usize + pub fn push

(&mut self, part: P) where P: Into, { @@ -275,10 +274,6 @@ impl Content { if let Content::MultiPart(parts) = self { parts.push(part.into()); - - parts.len() - 1 - } else { - unreachable!("Content is not MultiPart"); } } @@ -693,9 +688,11 @@ impl From for Block { impl From for Block { fn from(image: image::RgbaImage) -> Self { Image::encode(MediaType::Png, image) - // Unwrap can never panic unless the PNG encoding fails. + // Unwrap can never panic unless the PNG encoding fails, which + // should really never happen, but no matter what we don't panic. .unwrap_or_else(|e| { - eprintln!("Error encoding image: {}", e); + #[cfg(feature = "log")] + log::error!("Error encoding image: {}", e); Image::from_parts(MediaType::Png, String::new()) }) .into() @@ -838,7 +835,11 @@ pub enum MediaType { impl std::fmt::Display for MediaType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Use serde to get the string representation. - write!(f, "{}", serde_json::to_string(self).unwrap()) + write!( + f, + "{}", + serde_json::to_string(self).unwrap().trim_matches('"') + ) } } @@ -899,6 +900,12 @@ mod tests { {"type": "text", "text": "How are you?"} ]"#; + #[test] + fn test_role_display() { + assert_eq!(Role::User.to_string(), "User"); + assert_eq!(Role::Assistant.to_string(), "Assistant"); + } + #[test] fn deserialize_content() { let content: Content = serde_json::from_str(CONTENT_SINGLE).unwrap(); @@ -919,6 +926,12 @@ mod tests { assert_eq!(message.to_string(), "### User\n\nHello, world"); } + #[test] + fn test_message_from_role_string_tuple() { + let message: Message = (Role::User, "Hello, world!".to_string()).into(); + assert_eq!(message.to_string(), "### User\n\nHello, world!"); + } + #[test] #[cfg(feature = "markdown")] fn test_merge_deltas() { @@ -1164,5 +1177,59 @@ mod tests { assert_eq!(block.tool_use(), Some(&expected)); } + #[test] + fn test_block_from_str() { + let block: Block = "Hello, world!".into(); + assert_eq!(block.to_string(), "Hello, world!"); + } + + #[test] + fn test_block_from_string() { + let block: Block = "Hello, world!".to_string().into(); + assert_eq!(block.to_string(), "Hello, world!"); + } + + #[test] + fn test_block_from_image() { + let image = Image::from_parts(MediaType::Png, "data".to_string()); + let block: Block = image.into(); + assert_eq!(block.to_string(), "![Image]()"); + } + // TODO: Image tests + #[test] + #[cfg(feature = "png")] + fn test_block_from_rgba_image() { + let image = image::RgbaImage::new(1, 1); + let block: Block = image.into(); + assert!(matches!(block, Block::Image { .. })); + } + + #[test] + #[cfg(feature = "png")] + fn test_block_from_dynamic_image() { + let image = image::DynamicImage::new_rgba8(1, 1); + let block: Block = image.into(); + assert!(matches!(block, Block::Image { .. })); + } + + #[test] + #[cfg(feature = "png")] + fn test_image_from_compressed() { + use std::io::Cursor; + + // Encode a sample image + let expected = image::RgbaImage::new(1, 1); + let mut encoded = Cursor::new(vec![]); + expected + .write_to(&mut encoded, image::ImageFormat::Png) + .unwrap(); + + // Decode the image + let image = + Image::from_compressed(MediaType::Png, encoded.into_inner()); + let actual: image::RgbaImage = image.try_into().unwrap(); + + assert_eq!(actual, expected); + } } diff --git a/src/response.rs b/src/response.rs index 806f64a..b5a29d2 100644 --- a/src/response.rs +++ b/src/response.rs @@ -159,6 +159,7 @@ mod tests { }; assert!(response.into_stream().is_some()); + assert!(RESPONSE.into_stream().is_none()); } #[test] @@ -171,23 +172,53 @@ mod tests { stream: mock_stream, }; - assert!(response.into_stream().is_some()); + let _stream = response.unwrap_stream(); + } + + #[test] + #[should_panic] + fn test_unwrap_stream_panics() { + let _panic = RESPONSE.unwrap_stream(); } #[test] fn test_unwrap_message() { assert_eq!( - RESPONSE.into_message().unwrap().content.to_string(), + RESPONSE.unwrap_message().content.to_string(), "Hello, world!" ); } + #[test] + #[should_panic] + fn test_unwrap_message_panics() { + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + let _panic = response.unwrap_message(); + } + #[test] fn test_message() { assert_eq!( RESPONSE.message().unwrap().content.to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.message().is_none()); } #[test] @@ -196,6 +227,16 @@ mod tests { RESPONSE.into_message().unwrap().content.to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_message().is_none()); } #[test] @@ -209,6 +250,16 @@ mod tests { .to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.into_response_message().is_none()); } #[test] @@ -222,6 +273,16 @@ mod tests { .to_string(), "Hello, world!" ); + + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + assert!(response.response_message().is_none()); } #[test] @@ -235,4 +296,18 @@ mod tests { "Hello, world!" ); } + + #[test] + #[should_panic] + fn test_unwrap_response_message_panics() { + let mock_stream = crate::stream::tests::mock_stream(include_str!( + "../test/data/sse.stream.txt" + )); + + let response = Response::Stream { + stream: mock_stream, + }; + + let _panic = response.unwrap_response_message(); + } }