diff --git a/src/unzip/cloneable_seekable_reader.rs b/src/unzip/cloneable_seekable_reader.rs index 5212d10..dee61f6 100644 --- a/src/unzip/cloneable_seekable_reader.rs +++ b/src/unzip/cloneable_seekable_reader.rs @@ -7,20 +7,13 @@ // except according to those terms. use std::{ - fs::File, - io::{BufReader, Read, Seek, SeekFrom}, + io::{Read, Seek, SeekFrom}, sync::{Arc, Mutex}, }; -/// A trait to represent some reader which has a total length known in -/// advance. This is roughly equivalent to the nightly -/// [`Seek::stream_len`] API. -pub(crate) trait HasLength { - /// Return the current total length of this stream. - fn len(&self) -> u64; -} +use super::determine_stream_len; -struct Inner { +struct Inner { /// The underlying Read implementation. r: R, /// The position of r. @@ -29,7 +22,7 @@ struct Inner { len: Option, } -impl Inner { +impl Inner { fn new(r: R) -> Self { Self { r, @@ -39,14 +32,15 @@ impl Inner { } /// Get the length of the data stream. This is assumed to be constant. - fn len(&mut self) -> u64 { + fn len(&mut self) -> std::io::Result { + // Return cached size if let Some(len) = self.len { - return len; + return Ok(len); } - let len = self.r.len(); + let len = determine_stream_len(&mut self.r)?; self.len = Some(len); - len + Ok(len) } /// Read into the given buffer, starting at the given offset in the data stream. @@ -67,14 +61,14 @@ impl Inner { /// and thus can be cloned cheaply. It supports seeking; each cloned instance /// maintains its own pointer into the file, and the underlying instance /// is seeked prior to each read. -pub(crate) struct CloneableSeekableReader { +pub(crate) struct CloneableSeekableReader { /// The wrapper around the Read implementation, shared between threads. inner: Arc>>, /// The position of _this_ reader. pos: u64, } -impl Clone for CloneableSeekableReader { +impl Clone for CloneableSeekableReader { fn clone(&self) -> Self { Self { inner: self.inner.clone(), @@ -83,7 +77,7 @@ impl Clone for CloneableSeekableReader { } } -impl CloneableSeekableReader { +impl CloneableSeekableReader { /// Constructor. Takes ownership of the underlying `Read`. /// You should pass in only streams whose total length you expect /// to be fixed and unchanging. Odd behavior may occur if the length @@ -97,7 +91,7 @@ impl CloneableSeekableReader { } } -impl Read for CloneableSeekableReader { +impl Read for CloneableSeekableReader { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { let mut inner = self.inner.lock().unwrap(); let read_result = inner.read_at(self.pos, buf); @@ -114,12 +108,12 @@ impl Read for CloneableSeekableReader { } } -impl Seek for CloneableSeekableReader { +impl Seek for CloneableSeekableReader { fn seek(&mut self, pos: SeekFrom) -> std::io::Result { let new_pos = match pos { SeekFrom::Start(pos) => pos, SeekFrom::End(offset_from_end) => { - let file_len = self.inner.lock().unwrap().len(); + let file_len = self.inner.lock().unwrap().len()?; if -offset_from_end as u64 > file_len { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, @@ -146,30 +140,12 @@ impl Seek for CloneableSeekableReader { } } -impl HasLength for BufReader { - fn len(&self) -> u64 { - self.get_ref().len() - } -} - -impl HasLength for File { - fn len(&self) -> u64 { - self.metadata().unwrap().len() - } -} - #[cfg(test)] mod test { - use super::{CloneableSeekableReader, HasLength}; + use super::CloneableSeekableReader; use std::io::{Cursor, Read, Seek, SeekFrom}; use test_log::test; - impl HasLength for Cursor> { - fn len(&self) -> u64 { - self.get_ref().len() as u64 - } - } - #[test] fn test_cloneable_seekable_reader() -> std::io::Result<()> { let buf: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; diff --git a/src/unzip/mod.rs b/src/unzip/mod.rs index 7703158..840c81c 100644 --- a/src/unzip/mod.rs +++ b/src/unzip/mod.rs @@ -14,7 +14,7 @@ mod seekable_http_reader; use std::{ borrow::Cow, fs::File, - io::{ErrorKind, Read, Seek}, + io::{ErrorKind, Read, Seek, SeekFrom}, path::{Path, PathBuf}, sync::{Arc, Mutex}, }; @@ -27,10 +27,16 @@ use crate::unzip::{ cloneable_seekable_reader::CloneableSeekableReader, progress_updater::ProgressUpdater, }; -use self::{ - cloneable_seekable_reader::HasLength, - seekable_http_reader::{AccessPattern, SeekableHttpReader, SeekableHttpReaderEngine}, -}; +use self::seekable_http_reader::{AccessPattern, SeekableHttpReader, SeekableHttpReaderEngine}; + +pub(crate) fn determine_stream_len(stream: &mut R) -> std::io::Result { + let old_pos = stream.stream_position()?; + let len = stream.seek(SeekFrom::End(0))?; + if old_pos != len { + stream.seek(SeekFrom::Start(old_pos))?; + } + Ok(len) +} /// Options for unzipping. pub struct UnzipOptions<'a, 'b> { @@ -159,11 +165,11 @@ impl UnzipEngineImpl for UnzipUriEngine { impl UnzipEngine { /// Create an unzip engine which knows how to unzip a file. - pub fn for_file(zipfile: File) -> Result { + pub fn for_file(mut zipfile: File) -> Result { // The following line doesn't actually seem to make any significant // performance difference. // let zipfile = BufReader::new(zipfile); - let compressed_length = zipfile.len(); + let compressed_length = determine_stream_len(&mut zipfile)?; let zipfile = CloneableSeekableReader::new(zipfile); Ok(Self { zipfile: Box::new(UnzipFileEngine(ZipArchive::new(zipfile)?)), @@ -209,7 +215,7 @@ impl UnzipEngine { let mut response = reqwest::blocking::get(uri)?; let mut tempfile = tempfile::tempfile()?; std::io::copy(&mut response, &mut tempfile)?; - let compressed_length = tempfile.len(); + let compressed_length = determine_stream_len(&mut tempfile)?; let zipfile = CloneableSeekableReader::new(tempfile); ( compressed_length, diff --git a/src/unzip/seekable_http_reader.rs b/src/unzip/seekable_http_reader.rs index 6e3f2ef..61cd61a 100644 --- a/src/unzip/seekable_http_reader.rs +++ b/src/unzip/seekable_http_reader.rs @@ -18,10 +18,7 @@ use ranges::Ranges; use reqwest::blocking::Response; use thiserror::Error; -use super::{ - cloneable_seekable_reader::HasLength, - http_range_reader::{self, RangeFetcher}, -}; +use super::http_range_reader::{self, RangeFetcher}; /// This is how much we read from the underlying HTTP stream in a given thread, /// before signalling other threads that they may wish to continue with their @@ -636,12 +633,6 @@ impl Read for SeekableHttpReader { } } -impl HasLength for SeekableHttpReader { - fn len(&self) -> u64 { - self.engine.len() - } -} - #[cfg(test)] mod tests { use ripunzip_test_utils::{ExpectedRange, RangeAwareResponse, RangeAwareResponseType};