diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06fa2cf..073da68 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ on: merge_group: env: - RUST_MIN: "1.75" + RUST_MIN: "1.80" jobs: test: diff --git a/Cargo.toml b/Cargo.toml index b0b8353..e705a76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ edition = "2021" name = "stream-download" version = "0.15.1" -rust-version = "1.75.0" +rust-version = "1.80.0" authors = ["Austin Schey "] license = "MIT OR Apache-2.0" readme = "README.md" @@ -27,7 +27,6 @@ reqwest = { version = "0.12.2", features = [ "stream", ], default-features = false, optional = true } reqwest-middleware = { version = ">=0.3,<0.5", optional = true } -tap = "1.0.1" tempfile = { version = "3.1", optional = true } thiserror = "2.0.1" tokio = { version = "1.27", features = ["sync", "macros", "rt", "time"] } diff --git a/README.md b/README.md index 69233e1..bbd928e 100644 --- a/README.md +++ b/README.md @@ -370,4 +370,4 @@ for dynamically modifying each HTTP request. ## Supported Rust Versions -The MSRV is currently `1.75.0`. +The MSRV is currently `1.80.0`. diff --git a/src/http/reqwest_client.rs b/src/http/reqwest_client.rs index b8bb6f6..6b8138e 100644 --- a/src/http/reqwest_client.rs +++ b/src/http/reqwest_client.rs @@ -1,12 +1,11 @@ //! Adapters for using [`reqwest`] with `stream-download` use std::str::FromStr; -use std::sync::OnceLock; +use std::sync::LazyLock; use bytes::Bytes; use futures::Stream; use reqwest::header::{self, AsHeaderName, HeaderMap}; -use tap::TapFallible; use tracing::warn; use super::{DecodeError, RANGE_HEADER_KEY, format_range_header_bytes}; @@ -21,7 +20,7 @@ impl ResponseHeaders for HeaderMap { fn get_header_str(headers: &HeaderMap, key: K) -> Option<&str> { headers.get(key).and_then(|val| { val.to_str() - .tap_err(|e| warn!("error converting header value: {e:?}")) + .inspect_err(|e| warn!("error converting header value: {e:?}")) .ok() }) } @@ -64,7 +63,7 @@ impl ClientResponse for reqwest::Response { fn content_length(&self) -> Option { get_header_str(self.headers(), header::CONTENT_LENGTH).and_then(|content_length| { u64::from_str(content_length) - .tap_err(|e| warn!("invalid content length value: {e:?}")) + .inspect_err(|e| warn!("invalid content length value: {e:?}")) .ok() }) } @@ -96,7 +95,7 @@ impl ClientResponse for reqwest::Response { } // per reqwest's docs, it's advisable to create a single client and reuse it -static CLIENT: OnceLock = OnceLock::new(); +static CLIENT: LazyLock = LazyLock::new(reqwest::Client::new); impl Client for reqwest::Client { type Url = reqwest::Url; @@ -105,7 +104,7 @@ impl Client for reqwest::Client { type Headers = HeaderMap; fn create() -> Self { - CLIENT.get_or_init(Self::new).clone() + CLIENT.clone() } async fn get(&self, url: &Self::Url) -> Result { diff --git a/src/http/reqwest_middleware_client.rs b/src/http/reqwest_middleware_client.rs index a14d6b8..2b692f8 100644 --- a/src/http/reqwest_middleware_client.rs +++ b/src/http/reqwest_middleware_client.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; use parking_lot::Mutex; use reqwest::header::HeaderMap; @@ -6,21 +6,14 @@ use reqwest_middleware::Middleware; use super::{Client, RANGE_HEADER_KEY, format_range_header_bytes}; -static DEFAULT_MIDDLEWARE: OnceLock>>> = - OnceLock::new(); - -fn get_middleware() -> &'static Mutex>> { - DEFAULT_MIDDLEWARE.get_or_init(|| Mutex::new([].into())) -} +static DEFAULT_MIDDLEWARE: LazyLock>>> = + LazyLock::new(|| Mutex::new([].into())); pub(crate) fn add_default_middleware(middleware: M) where M: Middleware, { - DEFAULT_MIDDLEWARE - .get_or_init(|| Mutex::new([].into())) - .lock() - .push(Arc::new(middleware)); + DEFAULT_MIDDLEWARE.lock().push(Arc::new(middleware)); } impl Client for reqwest_middleware::ClientWithMiddleware { @@ -32,7 +25,7 @@ impl Client for reqwest_middleware::ClientWithMiddleware { fn create() -> Self { Self::new( reqwest::Client::create(), - get_middleware().lock().clone().into_boxed_slice(), + DEFAULT_MIDDLEWARE.lock().clone().into_boxed_slice(), ) } diff --git a/src/lib.rs b/src/lib.rs index 2d8a272..8e3f747 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,6 @@ pub use settings::*; use source::handle::SourceHandle; use source::{DecodeError, Source, SourceStream}; use storage::StorageProvider; -use tap::Tap; use tokio_util::sync::CancellationToken; use tracing::{debug, error, instrument, trace}; @@ -449,8 +448,7 @@ impl StreamDownload

{ settings: Settings, ) -> Result> where - S: SourceStream, - S::Error: Debug + Send, + S: SourceStream, F: FnOnce() -> Fut + Send + 'static, Fut: Future> + Send, { @@ -509,7 +507,7 @@ impl StreamDownload

{ } fn handle_read(&mut self, buf: &mut [u8]) -> io::Result { - let res = self.output_reader.read(buf).tap(|l| { + let res = self.output_reader.read(buf).inspect(|l| { trace!(read_length = format!("{l:?}"), "returning read"); }); self.handle.notify_read(); @@ -633,7 +631,7 @@ impl Seek for StreamDownload

{ return self .output_reader .seek(SeekFrom::Start(absolute_seek_position)) - .tap(|p| debug!(position = format!("{p:?}"), "returning seek position")); + .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position")); } self.handle.request_position(absolute_seek_position); @@ -648,7 +646,7 @@ impl Seek for StreamDownload

{ self.output_reader .seek(SeekFrom::Start(absolute_seek_position)) - .tap(|p| debug!(position = format!("{p:?}"), "returning seek position")) + .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position")) } } diff --git a/src/process/mod.rs b/src/process/mod.rs index 5e7e288..1880537 100644 --- a/src/process/mod.rs +++ b/src/process/mod.rs @@ -26,7 +26,6 @@ use bytes::Bytes; pub use command_builder::*; pub use ffmpeg::*; use futures::Stream; -use tap::TapFallible; use tempfile::NamedTempFile; use tracing::{debug, error, warn}; pub use yt_dlp::*; @@ -209,17 +208,17 @@ impl ProcessStream { for file in &mut self.stderr_files { let _ = file .flush() - .tap_err(|e| error!("error flushing file: {e:?}")); + .inspect_err(|e| error!("error flushing file: {e:?}")); // Need to reopen the file to access the contents since it was written to from an // external process if let Ok(mut file_handle) = file .reopen() - .tap_err(|e| error!("error opening file: {e:?}")) + .inspect_err(|e| error!("error opening file: {e:?}")) { let mut buf = String::new(); let _ = file_handle .read_to_string(&mut buf) - .tap_err(|e| error!("error reading file: {e:?}")); + .inspect_err(|e| error!("error reading file: {e:?}")); warn!("stderr from child process: {buf}"); } } @@ -227,7 +226,9 @@ impl ProcessStream { fn close_stderr_files(&mut self) { for file in mem::take(&mut self.stderr_files) { - let _ = file.close().tap_err(|e| warn!("error closing file: {e:?}")); + let _ = file + .close() + .inspect_err(|e| warn!("error closing file: {e:?}")); } } } diff --git a/src/source/handle.rs b/src/source/handle.rs index 181c5eb..6f09925 100644 --- a/src/source/handle.rs +++ b/src/source/handle.rs @@ -7,7 +7,6 @@ use std::time::Instant; use parking_lot::{Condvar, Mutex, RwLock}; use rangemap::RangeSet; -use tap::TapFallible; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::{Notify, mpsc}; use tracing::{debug, error}; @@ -47,7 +46,7 @@ impl SourceHandle { pub(crate) fn seek(&self, position: u64) { self.seek_tx .try_send(position) - .tap_err(|e| { + .inspect_err(|e| { if let TrySendError::Full(capacity) = e { error!("Seek buffer full. Capacity: {capacity}"); } diff --git a/src/source/mod.rs b/src/source/mod.rs index 308bb4d..5727a28 100644 --- a/src/source/mod.rs +++ b/src/source/mod.rs @@ -14,7 +14,6 @@ use futures::{Future, Stream, StreamExt, TryStream}; use handle::{ DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle, }; -use tap::TapFallible; use tokio::sync::mpsc; use tokio::time::timeout; use tokio_util::sync::CancellationToken; @@ -133,9 +132,10 @@ pub(crate) struct Source { cancellation_token: CancellationToken, } -impl Source +impl Source where - S::Error: Debug, + S: SourceStream, + W: StorageWriter, { pub(crate) fn new( writer: W, @@ -255,7 +255,7 @@ where // we'll cap the reconnect time to prevent additional delays between reconnect attempts. let reconnect_pos = tokio::time::timeout(self.retry_timeout, stream.reconnect(pos)).await; if reconnect_pos - .tap_err(|e| warn!("error attempting to reconnect: {e:?}")) + .inspect_err(|e| warn!("error attempting to reconnect: {e:?}")) .is_ok() { if let Some(on_reconnect) = &mut self.on_reconnect { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index c8527fc..1499fc3 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -6,7 +6,7 @@ use std::{fs, io}; use opendal::{Operator, services}; use rstest::rstest; -use setup::{Command, ErrorTestStorageProvider, SERVER_ADDR, SERVER_RT, TestClient}; +use setup::{Command, ErrorTestStorageProvider, SERVER_RT, TestClient, server_addr}; use stream_download::async_read::AsyncReadStreamParams; use stream_download::http::{HttpStream, HttpStreamError}; use stream_download::open_dal::{OpenDalStream, OpenDalStreamParams}; @@ -42,9 +42,9 @@ fn compare(a: impl Into>, b: impl Into>) { #[case(256*1024)] #[case(1024*1024)] fn new(#[case] prefetch_bytes: u64) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new::>( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -75,12 +75,12 @@ fn new_with_middleware(#[case] prefetch_bytes: u64) { use reqwest_retry::RetryTransientMiddleware; use reqwest_retry::policies::ExponentialBackoff; - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); Settings::add_default_middleware(RetryTransientMiddleware::new_with_policy(retry_policy)); let mut reader = StreamDownload::new_http_with_middleware( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -108,9 +108,8 @@ fn new_with_middleware(#[case] prefetch_bytes: u64) { #[case(256*1024)] #[case(1024*1024)] fn open_dal_chunk_size(#[case] prefetch_bytes: u64, #[values(745, 1234, 4096)] chunk_size: usize) { - SERVER_RT.get().unwrap().block_on(async move { - let builder = - services::Http::default().endpoint(&format!("http://{}", SERVER_ADDR.get().unwrap())); + SERVER_RT.block_on(async move { + let builder = services::Http::default().endpoint(&format!("http://{}", server_addr())); let operator = Operator::new(builder).unwrap().finish(); let mut reader = StreamDownload::new_open_dal( OpenDalStreamParams::new(operator, "music.mp3") @@ -140,10 +139,10 @@ fn open_dal_chunk_size(#[case] prefetch_bytes: u64, #[values(745, 1234, 4096)] c #[case(256*1024)] #[case(1024*1024)] fn from_stream_http(#[case] prefetch_bytes: u64) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let stream = http::HttpStream::new( reqwest::Client::new(), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -190,9 +189,8 @@ fn from_stream_http(#[case] prefetch_bytes: u64) { #[case(256*1024)] #[case(1024*1024)] fn from_stream_open_dal(#[case] prefetch_bytes: u64) { - SERVER_RT.get().unwrap().block_on(async move { - let builder = - services::Http::default().endpoint(&format!("http://{}", SERVER_ADDR.get().unwrap())); + SERVER_RT.block_on(async move { + let builder = services::Http::default().endpoint(&format!("http://{}", server_addr())); let operator = Operator::new(builder).unwrap().finish(); let stream = OpenDalStream::new(OpenDalStreamParams::new(operator, "music.mp3")) .await @@ -225,9 +223,9 @@ fn from_stream_open_dal(#[case] prefetch_bytes: u64) { #[rstest] fn handle_error() { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let reader = StreamDownload::new_http( - format!("http://{}/invalid.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/invalid.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -249,9 +247,9 @@ fn basic_download( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), storage, @@ -283,9 +281,9 @@ fn tempfile_builder( )] storage: TempStorageProvider, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), storage, @@ -306,9 +304,9 @@ fn tempfile_builder( #[test] fn return_error() { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ErrorTestStorageProvider(MemoryStorageProvider), @@ -333,7 +331,7 @@ fn slow_download( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -355,7 +353,7 @@ fn slow_download( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, true), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -384,7 +382,7 @@ fn cancel_on_drop( #[values(0, 1, 256*1024, 1024*1024)] prefetch_bytes: u64, #[values(MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -402,7 +400,7 @@ fn cancel_on_drop( let reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, true), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -433,7 +431,7 @@ fn retry_stuck_download( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -470,7 +468,7 @@ fn retry_stuck_download( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, true), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -502,10 +500,9 @@ fn bounded( #[values(256*1024, 300*1024)] bounded_length: usize, #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: T, ) where - T: StorageProvider, - ::Reader: RefUnwindSafe + UnwindSafe, + T: StorageProvider, { - let buf = SERVER_RT.get().unwrap().block_on(async move { + let buf = SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -530,7 +527,7 @@ fn bounded( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, false), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -571,10 +568,9 @@ fn adaptive( #[values(true, false)] has_content_length: bool, #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: T, ) where - T: StorageProvider + 'static, - ::Reader: RefUnwindSafe + UnwindSafe, + T: StorageProvider + 'static, { - let buf = SERVER_RT.get().unwrap().block_on(async move { + let buf = SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -596,7 +592,7 @@ fn adaptive( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, has_content_length), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -635,7 +631,7 @@ fn adaptive( #[rstest] fn bounded_seek_near_beginning() { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -650,7 +646,7 @@ fn bounded_seek_near_beginning() { let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, false), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -711,9 +707,9 @@ fn backpressure( #[values(4096, 4096*2+1, 256*1024)] bounded_size: usize, #[values(1, 5)] multiplier: usize, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), BoundedStorageProvider::new( @@ -747,7 +743,7 @@ fn seek_basic( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -769,7 +765,7 @@ fn seek_basic( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, true), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -806,9 +802,8 @@ fn seek_basic_open_dal( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { - let builder = - services::Http::default().endpoint(&format!("http://{}", SERVER_ADDR.get().unwrap())); + SERVER_RT.block_on(async move { + let builder = services::Http::default().endpoint(&format!("http://{}", server_addr())); let operator = Operator::new(builder).unwrap().finish(); @@ -852,7 +847,7 @@ fn seek_all( #[values(TempStorageProvider::default(), MemoryStorageProvider)] storage: impl StorageProvider + 'static, ) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let (tx, mut rx) = mpsc::unbounded_channel::<(Command, oneshot::Sender)>(); let handle = tokio::spawn(async move { @@ -884,7 +879,7 @@ fn seek_all( let mut reader = StreamDownload::from_stream( http::HttpStream::new( TestClient::new(tx, true), - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), ) @@ -954,9 +949,9 @@ fn seek_all( #[case(256*1024)] #[case(1024*1024)] fn cancel_download(#[case] prefetch_bytes: u64) { - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -988,7 +983,7 @@ fn cancel_download(#[case] prefetch_bytes: u64) { fn on_progress(#[case] prefetch_bytes: u64) { let (tx, mut rx) = mpsc::unbounded_channel::<(Option, StreamState)>(); - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let progress_task = tokio::spawn(async move { let next = rx.recv().await.unwrap(); assert!(matches!( @@ -1020,7 +1015,7 @@ fn on_progress(#[case] prefetch_bytes: u64) { } }); let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -1054,7 +1049,7 @@ fn on_progress(#[case] prefetch_bytes: u64) { fn on_progress_no_prefetch() { let (tx, mut rx) = mpsc::unbounded_channel::<(Option, StreamState)>(); - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let progress_task = tokio::spawn(async move { let next = rx.recv().await.unwrap(); assert!(matches!( @@ -1073,7 +1068,7 @@ fn on_progress_no_prefetch() { } }); let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, @@ -1108,7 +1103,7 @@ fn on_progress_no_prefetch() { fn on_progress_excessive_prefetch(#[case] prefetch_bytes: u64) { let (tx, mut rx) = mpsc::unbounded_channel::<(Option, StreamState)>(); - SERVER_RT.get().unwrap().block_on(async move { + SERVER_RT.block_on(async move { let progress_task = tokio::spawn(async move { let next = rx.recv().await.unwrap(); assert!(matches!( @@ -1127,7 +1122,7 @@ fn on_progress_excessive_prefetch(#[case] prefetch_bytes: u64) { } }); let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), MemoryStorageProvider, diff --git a/tests/proptest.rs b/tests/proptest.rs index 99e1567..5bd52f1 100644 --- a/tests/proptest.rs +++ b/tests/proptest.rs @@ -3,7 +3,7 @@ use std::num::NonZeroUsize; use std::{fs, io}; use proptest::prelude::*; -use setup::{SERVER_ADDR, SERVER_RT}; +use setup::{SERVER_RT, server_addr}; use stream_download::storage::bounded::BoundedStorageProvider; use stream_download::storage::memory::MemoryStorageProvider; use stream_download::{Settings, StreamDownload}; @@ -45,9 +45,9 @@ prop_compose! { proptest! { #[test] fn proptest(StreamParams { read_len, bounded_size, prefetch_bytes } in input_sizes()) { - let buf = SERVER_RT.get().unwrap().block_on(async move { + let buf = SERVER_RT.block_on(async move { let mut reader = StreamDownload::new_http( - format!("http://{}/music.mp3", SERVER_ADDR.get().unwrap()) + format!("http://{}/music.mp3", server_addr()) .parse() .unwrap(), BoundedStorageProvider::new(MemoryStorageProvider, diff --git a/tests/setup.rs b/tests/setup.rs index b41ed01..409354c 100644 --- a/tests/setup.rs +++ b/tests/setup.rs @@ -1,8 +1,8 @@ use std::io::{self, Seek, SeekFrom, Write}; -use std::net::SocketAddr; +use std::net::{SocketAddr, TcpListener}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, LazyLock}; use std::task::{Context, Poll}; use std::time::Duration; @@ -19,23 +19,27 @@ use tokio::sync::{mpsc, oneshot}; use tower_http::services::ServeDir; use tracing_subscriber::EnvFilter; -pub static SERVER_RT: OnceLock = OnceLock::new(); -pub static SERVER_ADDR: OnceLock = OnceLock::new(); +pub static SERVER_RT: LazyLock = LazyLock::new(|| Runtime::new().unwrap()); +pub static SERVER_LISTENER: LazyLock = LazyLock::new(|| { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + listener.set_nonblocking(true).unwrap(); + listener +}); + +pub fn server_addr() -> SocketAddr { + SERVER_LISTENER.local_addr().unwrap() +} #[ctor] fn setup() { setup_logger(); - let rt = SERVER_RT.get_or_init(|| Runtime::new().unwrap()); - let _guard = rt.enter(); - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - listener.set_nonblocking(true).unwrap(); + let _guard = SERVER_RT.enter(); - SERVER_ADDR.get_or_init(|| listener.local_addr().unwrap()); let service = ServeDir::new("./assets"); let router = Router::new().fallback_service(service); - - rt.spawn(async move { + let listener = SERVER_LISTENER.try_clone().unwrap(); + SERVER_RT.spawn(async move { let listener = tokio::net::TcpListener::from_std(listener).unwrap(); axum::serve(listener, router).await.unwrap(); });