Skip to content

Commit

Permalink
webpsan: remove async API
Browse files Browse the repository at this point in the history
Support for the async API bloats code size significantly even when you're only using the sync API.
  • Loading branch information
jessa0 committed Oct 19, 2023
1 parent d2bfcda commit 6886411
Show file tree
Hide file tree
Showing 14 changed files with 299 additions and 360 deletions.
59 changes: 59 additions & 0 deletions common/src/async_skip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//! Utility functions for the [`AsyncSkip`] trait.
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures_util::io::BufReader;
use futures_util::{AsyncBufRead, AsyncRead};

use crate::AsyncSkip;

/// Poll skipping `amount` bytes in a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// Bytes already held in the reader's buffer are discarded with `consume`; only the part of `amount` that
/// exceeds the buffered bytes is forwarded to the inner reader's `poll_skip`.
///
/// Returns [`Poll::Pending`] without consuming any buffered bytes while the inner skip is pending, and
/// propagates any error the inner skip reports.
pub fn poll_buf_skip<R: AsyncRead + AsyncSkip>(
    mut reader: Pin<&mut BufReader<R>>,
    cx: &mut Context<'_>,
    amount: u64,
) -> Poll<io::Result<()>> {
    // Compare in `u64`: casting `amount` down to `usize` first would truncate it on targets where `usize` is
    // narrower than 64 bits, making a large skip consume too little of the buffer.
    let buffered = (reader.buffer().len() as u64).min(amount);
    let beyond_buffer = amount - buffered;
    if beyond_buffer != 0 {
        ready!(reader.as_mut().get_pin_mut().poll_skip(cx, beyond_buffer))?;
    }
    // `buffered` never exceeds the buffer length, so it always fits in `usize`.
    reader.consume(buffered as usize);
    Poll::Ready(Ok(()))
}

/// Asynchronously skip `amount` bytes in a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_skip`]: it resolves once that function reports ready.
pub async fn buf_skip<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>, amount: u64) -> io::Result<()> {
    let skip = poll_fn(|cx| poll_buf_skip(reader.as_mut(), cx, amount));
    skip.await
}

/// Poll the stream position for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
pub fn poll_buf_stream_position<R: AsyncRead + AsyncSkip>(
mut reader: Pin<&mut BufReader<R>>,
cx: &mut Context<'_>,
) -> Poll<io::Result<u64>> {
let stream_pos = ready!(reader.as_mut().get_pin_mut().poll_stream_position(cx))?;
Poll::Ready(Ok(stream_pos.saturating_sub(reader.buffer().len() as u64)))
}

/// Asynchronously return the stream position for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_stream_position`].
pub async fn buf_stream_position<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>) -> io::Result<u64> {
    let position = poll_fn(|cx| poll_buf_stream_position(reader.as_mut(), cx));
    position.await
}

/// Poll the stream length for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// The query is delegated unchanged to the inner reader; the buffer contents are not consulted.
pub fn poll_buf_stream_len<R: AsyncRead + AsyncSkip>(
    mut reader: Pin<&mut BufReader<R>>,
    cx: &mut Context<'_>,
) -> Poll<io::Result<u64>> {
    let inner = reader.as_mut().get_pin_mut();
    inner.poll_stream_len(cx)
}

/// Asynchronously return the stream length for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_stream_len`].
pub async fn buf_stream_len<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>) -> io::Result<u64> {
    let len = poll_fn(|cx| poll_buf_stream_len(reader.as_mut(), cx));
    len.await
}
54 changes: 3 additions & 51 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
#[macro_use]
pub mod macros;

pub mod async_skip;
pub mod error;
pub mod parse;
pub mod skip;
pub mod sync;
pub mod util;

use std::future::poll_fn;
use std::io;
use std::io::Seek;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures_util::io::BufReader;
use futures_util::{AsyncBufRead, AsyncRead, AsyncSeek};
use futures_util::AsyncSeek;

//
// public types
Expand Down Expand Up @@ -71,54 +71,6 @@ pub trait AsyncSkip {
// public functions
//

/// Poll skipping `amount` bytes in a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// Bytes already held in the reader's buffer are discarded with `consume`; only the part of `amount` that
/// exceeds the buffered bytes is forwarded to the inner reader's `poll_skip`.
///
/// Returns [`Poll::Pending`] without consuming any buffered bytes while the inner skip is pending, and
/// propagates any error the inner skip reports.
pub fn poll_buf_skip<R: AsyncRead + AsyncSkip>(
    mut reader: Pin<&mut BufReader<R>>,
    cx: &mut Context<'_>,
    amount: u64,
) -> Poll<io::Result<()>> {
    // Compare in `u64`: casting `amount` down to `usize` first would truncate it on targets where `usize` is
    // narrower than 64 bits, making a large skip consume too little of the buffer.
    let buffered = (reader.buffer().len() as u64).min(amount);
    let beyond_buffer = amount - buffered;
    if beyond_buffer != 0 {
        ready!(reader.as_mut().get_pin_mut().poll_skip(cx, beyond_buffer))?;
    }
    // `buffered` never exceeds the buffer length, so it always fits in `usize`.
    reader.consume(buffered as usize);
    Poll::Ready(Ok(()))
}

/// Asynchronously skip `amount` bytes in a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_skip`]: it resolves once that function reports ready.
pub async fn buf_skip<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>, amount: u64) -> io::Result<()> {
    let skip = poll_fn(|cx| poll_buf_skip(reader.as_mut(), cx, amount));
    skip.await
}

/// Poll the stream position for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
pub fn poll_buf_stream_position<R: AsyncRead + AsyncSkip>(
mut reader: Pin<&mut BufReader<R>>,
cx: &mut Context<'_>,
) -> Poll<io::Result<u64>> {
let stream_pos = ready!(reader.as_mut().get_pin_mut().poll_stream_position(cx))?;
Poll::Ready(Ok(stream_pos.saturating_sub(reader.buffer().len() as u64)))
}

/// Asynchronously return the stream position for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_stream_position`].
pub async fn buf_stream_position<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>) -> io::Result<u64> {
    let position = poll_fn(|cx| poll_buf_stream_position(reader.as_mut(), cx));
    position.await
}

/// Poll the stream length for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// The query is delegated unchanged to the inner reader; the buffer contents are not consulted.
pub fn poll_buf_stream_len<R: AsyncRead + AsyncSkip>(
    mut reader: Pin<&mut BufReader<R>>,
    cx: &mut Context<'_>,
) -> Poll<io::Result<u64>> {
    let inner = reader.as_mut().get_pin_mut();
    inner.poll_stream_len(cx)
}

/// Asynchronously return the stream length for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
///
/// This is the `async` form of [`poll_buf_stream_len`].
pub async fn buf_stream_len<R: AsyncRead + AsyncSkip>(mut reader: Pin<&mut BufReader<R>>) -> io::Result<u64> {
    let len = poll_fn(|cx| poll_buf_stream_len(reader.as_mut(), cx));
    len.await
}

//
// Skip impls
//
Expand Down
29 changes: 29 additions & 0 deletions common/src/skip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//! Utility functions for the [`Skip`] trait.
use std::io;
use std::io::{BufRead, BufReader, Read};

use crate::Skip;

/// Skip `amount` bytes in a [`BufReader`] implementing [`Read`] + [`Skip`].
///
/// Bytes already held in the reader's buffer are discarded with `consume`; only the part of `amount` that
/// exceeds the buffered bytes is forwarded to the inner reader's `skip`.
///
/// # Errors
///
/// Returns any error reported by the inner reader's `skip`; in that case no buffered bytes are consumed.
pub fn buf_skip<R: Read + Skip>(reader: &mut BufReader<R>, amount: u64) -> io::Result<()> {
    // Compare in `u64`: casting `amount` down to `usize` first would truncate it on targets where `usize` is
    // narrower than 64 bits, making a large skip consume too little of the buffer.
    let buffered = (reader.buffer().len() as u64).min(amount);
    let beyond_buffer = amount - buffered;
    if beyond_buffer != 0 {
        reader.get_mut().skip(beyond_buffer)?;
    }
    // `buffered` never exceeds the buffer length, so it always fits in `usize`.
    reader.consume(buffered as usize);
    Ok(())
}

/// Return the stream position for a [`BufReader`] implementing [`Read`] + [`Skip`].
///
/// The inner reader's position is reduced by the number of bytes currently sitting unread in the buffer
/// (saturating at zero), so the result reflects what the caller has actually consumed.
pub fn buf_stream_position<R: Read + Skip>(reader: &mut BufReader<R>) -> io::Result<u64> {
    let unread = reader.buffer().len() as u64;
    reader.get_mut().stream_position().map(|inner_pos| inner_pos.saturating_sub(unread))
}

/// Return the stream length for a [`BufReader`] implementing [`Read`] + [`Skip`].
pub fn buf_stream_len<R: Read + Skip>(reader: &mut BufReader<R>) -> io::Result<u64> {
reader.get_mut().stream_len()
}
4 changes: 3 additions & 1 deletion mp4san/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ use derive_builder::Builder;
use derive_more::Display;
use futures_util::io::BufReader;
use futures_util::{pin_mut, AsyncBufReadExt, AsyncRead};
use mediasan_common::async_skip::{
buf_skip as skip, buf_stream_len as stream_len, buf_stream_position as stream_position,
};
use mediasan_common::sync;
use mediasan_common::util::{checked_add_signed, IoResultExt};
use mediasan_common::{buf_skip as skip, buf_stream_len as stream_len, buf_stream_position as stream_position};

use crate::error::Report;
use crate::parse::error::{MultipleBoxes, WhileParsingBox};
Expand Down
2 changes: 0 additions & 2 deletions webpsan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ bitstream-io = "1.7.0"
bytes = "1.3.0"
derive_builder = "0.12.0"
derive_more = "0.99.17"
futures-util = { version = "0.3.28", default-features = false, features = ["io"] }
log = "0.4.17"
mediasan-common = { path = "../common", version = "=0.4.0" }
num-integer = { version = "0.1.45", default-features = false }
num-traits = { version = "0.2.16", default-features = false }
pin-project = "1.1.3"
thiserror = "1.0.38"

[dev-dependencies]
Expand Down
7 changes: 3 additions & 4 deletions webpsan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ parser can be avoided.

## Usage

The main entry points to the sanitizer are [`sanitize`]/[`sanitize_async`], which take a [`Read`] + [`Skip`] input. The
[`Skip`] trait represents a subset of the [`Seek`] trait; an input stream which can be skipped forward, but not
necessarily seeked to arbitrary positions.
The main entry point to the sanitizer is [`sanitize`], which takes a [`Read`] + [`Skip`] input. The [`Skip`] trait
represents a subset of the [`Seek`] trait; an input stream which can be skipped forward, but not necessarily seeked to
arbitrary positions.

```rust
let example_input = b"RIFF\x14\0\0\0WEBPVP8L\x08\0\0\0\x2f\0\0\0\0\x88\x88\x08";
Expand All @@ -23,7 +23,6 @@ types.
[Private Documentation](https://privacyresearchgroup.github.io/mp4san/private/webpsan/)

[`sanitize`]: https://privacyresearchgroup.github.io/mp4san/public/webpsan/fn.sanitize.html
[`sanitize_async`]: https://privacyresearchgroup.github.io/mp4san/public/webpsan/fn.sanitize_async.html
[`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html
[`Skip`]: https://privacyresearchgroup.github.io/mp4san/public/mediasan_common/trait.Skip.html
[`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
Expand Down
24 changes: 11 additions & 13 deletions webpsan/benches/bitstream.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::task::{Context, Poll};
use std::{io, pin::Pin};
use std::io;
use std::io::Read;

use bitstream_io::LE;
use criterion::async_executor::FuturesExecutor;
use criterion::measurement::Measurement;
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkGroup, Criterion};
use futures_util::AsyncRead;
use webpsan::parse::{BitBufReader, CanonicalHuffmanTree};
use webpsan::Error;

Expand All @@ -19,9 +17,9 @@ criterion_main!(benches);

struct BlackBoxZeroesInput;

impl AsyncRead for BlackBoxZeroesInput {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
black_box(Poll::Ready(Ok(buf.len())))
impl Read for BlackBoxZeroesInput {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
black_box(Ok(buf.len()))
}
}

Expand All @@ -48,11 +46,11 @@ fn read_huffman<M: Measurement, S: Clone>(mut group: BenchmarkGroup<'_, M>, code
let setup = || BitBufReader::<_, LE>::with_capacity(BlackBoxZeroesInput, buf_len);
group.throughput(criterion::Throughput::Bytes(buf_len as u64));
group.bench_function("buf_read_huffman", |bencher| {
bencher.to_async(FuturesExecutor).iter_batched(
bencher.iter_batched(
setup,
|mut reader| async move {
|mut reader| {
if code.longest_code_len() != 0 {
reader.fill_buf().await?;
reader.fill_buf()?;
}
for _ in 0..buf_len * 8 {
black_box(reader.buf_read_huffman(code))?;
Expand All @@ -63,11 +61,11 @@ fn read_huffman<M: Measurement, S: Clone>(mut group: BenchmarkGroup<'_, M>, code
)
});
group.bench_function("read_huffman", |bencher| {
bencher.to_async(FuturesExecutor).iter_batched(
bencher.iter_batched(
setup,
|mut reader| async move {
|mut reader| {
for _ in 0..buf_len * 8 {
black_box(reader.read_huffman(code).await)?;
black_box(reader.read_huffman(code))?;
}
Ok::<_, Error>(())
},
Expand Down
Loading

0 comments on commit 6886411

Please sign in to comment.