From afa1b5b827549d30902c4e2044a4cc453821fb6d Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Fri, 15 Nov 2024 13:54:56 +0100 Subject: [PATCH] Revert "load the inflate reader and writer to the stack" This reverts commit c26489305a7a68a24c397d93ced8277be7dabd88. --- zlib-rs/src/inflate.rs | 303 +++++++++++++++++++---------------------- 1 file changed, 138 insertions(+), 165 deletions(-) diff --git a/zlib-rs/src/inflate.rs b/zlib-rs/src/inflate.rs index c6056bd..d4c3ae2 100644 --- a/zlib-rs/src/inflate.rs +++ b/zlib-rs/src/inflate.rs @@ -490,6 +490,24 @@ impl<'a> State<'a> { } } +macro_rules! pull_byte { + ($self:expr) => { + match $self.bit_reader.pull_byte() { + Err(return_code) => return $self.inflate_leave(return_code), + Ok(_) => (), + } + }; +} + +macro_rules! need_bits { + ($self:expr, $n:expr) => { + match $self.bit_reader.need_bits($n) { + Err(return_code) => return $self.inflate_leave(return_code), + Ok(v) => v, + } + }; +} + // swaps endianness const fn zswap32(q: u32) -> u32 { u32::from_be(q.to_le()) @@ -499,53 +517,8 @@ const INFLATE_FAST_MIN_HAVE: usize = 15; const INFLATE_FAST_MIN_LEFT: usize = 260; impl<'a> State<'a> { - // NOTE: DO NOT RETURN FROM THIS FUNCTION! - // - // this function loads fields from `self` to the stack, and it is crucial that these fields - // are written back into `self` before exiting. - // - // Likewise, be careful with `self.foo()` function calls: the reader and writer are invalidated - // for the duration of this function! fn dispatch(&mut self) -> ReturnCode { - let mut writer; - let mut bit_reader; - - macro_rules! load { - ($state:expr) => { - bit_reader = core::mem::replace(&mut $state.bit_reader, BitReader::new(&[])); - writer = core::mem::replace(&mut $state.writer, Writer::new(&mut [])); - }; - } - - macro_rules! restore { - ($state:expr) => { - $state.bit_reader = bit_reader; - $state.writer = writer; - }; - } - - // load variables to the stack - load!(self); - - let return_code = 'label: loop { - macro_rules! pull_byte { - ($self:expr) => { - match bit_reader.pull_byte() { - Err(return_code) => break 'label return_code, - Ok(_) => (), - } - }; - } - - macro_rules! need_bits { - ($self:expr, $n:expr) => { - match bit_reader.need_bits($n) { - Err(return_code) => break 'label return_code, - Ok(v) => v, - } - }; - } - + 'label: loop { match self.mode { Mode::Head => { if self.wrap == 0 { @@ -557,15 +530,15 @@ impl<'a> State<'a> { need_bits!(self, 16); // Gzip - if (self.wrap & 2) != 0 && bit_reader.hold() == 0x8b1f { + if (self.wrap & 2) != 0 && self.bit_reader.hold() == 0x8b1f { if self.wbits == 0 { self.wbits = 15; } - let b0 = bit_reader.bits(8) as u8; - let b1 = (bit_reader.hold() >> 8) as u8; + let b0 = self.bit_reader.bits(8) as u8; + let b1 = (self.bit_reader.hold() >> 8) as u8; self.checksum = crc32(crate::CRC32_INITIAL_VALUE, &[b0, b1]); - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::Flags; @@ -578,19 +551,20 @@ impl<'a> State<'a> { // check if zlib header is allowed if (self.wrap & 1) == 0 - || ((bit_reader.bits(8) << 8) + (bit_reader.hold() >> 8)) % 31 != 0 + || ((self.bit_reader.bits(8) << 8) + (self.bit_reader.hold() >> 8)) % 31 + != 0 { self.mode = Mode::Bad; break 'label self.bad("incorrect header check\0"); } - if bit_reader.bits(4) != Z_DEFLATED as u64 { + if self.bit_reader.bits(4) != Z_DEFLATED as u64 { self.mode = Mode::Bad; break 'label self.bad("unknown compression method\0"); } - bit_reader.drop_bits(4); - let len = bit_reader.bits(4) as u8 + 8; + self.bit_reader.drop_bits(4); + let len = self.bit_reader.bits(4) as u8 + 8; if self.wbits == 0 { self.wbits = len; @@ -605,14 +579,14 @@ impl<'a> State<'a> { self.gzip_flags = 0; // indicate zlib header self.checksum = crate::ADLER32_INITIAL_VALUE as _; - if bit_reader.hold() & 0x200 != 0 { - bit_reader.init_bits(); + if self.bit_reader.hold() & 0x200 != 0 { + self.bit_reader.init_bits(); self.mode = Mode::DictId; continue 'label; } else { - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::Type; @@ -621,7 +595,7 @@ impl<'a> State<'a> { } Mode::Flags => { need_bits!(self, 16); - self.gzip_flags = bit_reader.hold() as i32; + self.gzip_flags = self.bit_reader.hold() as i32; // Z_DEFLATED = 8 is the only supported method if self.gzip_flags & 0xff != Z_DEFLATED { @@ -635,16 +609,16 @@ impl<'a> State<'a> { } if let Some(head) = self.head.as_mut() { - head.text = ((bit_reader.hold() >> 8) & 1) as i32; + head.text = ((self.bit_reader.hold() >> 8) & 1) as i32; } if (self.gzip_flags & 0x0200) != 0 && (self.wrap & 4) != 0 { - let b0 = bit_reader.bits(8) as u8; - let b1 = (bit_reader.hold() >> 8) as u8; + let b0 = self.bit_reader.bits(8) as u8; + let b1 = (self.bit_reader.hold() >> 8) as u8; self.checksum = crc32(self.checksum, &[b0, b1]); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::Time; continue 'label; @@ -652,15 +626,15 @@ impl<'a> State<'a> { Mode::Time => { need_bits!(self, 32); if let Some(head) = self.head.as_mut() { - head.time = bit_reader.hold() as z_size; + head.time = self.bit_reader.hold() as z_size; } if (self.gzip_flags & 0x0200) != 0 && (self.wrap & 4) != 0 { - let bytes = (bit_reader.hold() as u32).to_le_bytes(); + let bytes = (self.bit_reader.hold() as u32).to_le_bytes(); self.checksum = crc32(self.checksum, &bytes); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::Os; continue 'label; @@ -668,16 +642,16 @@ impl<'a> State<'a> { Mode::Os => { need_bits!(self, 16); if let Some(head) = self.head.as_mut() { - head.xflags = (bit_reader.hold() & 0xff) as i32; - head.os = (bit_reader.hold() >> 8) as i32; + head.xflags = (self.bit_reader.hold() & 0xff) as i32; + head.os = (self.bit_reader.hold() >> 8) as i32; } if (self.gzip_flags & 0x0200) != 0 && (self.wrap & 4) != 0 { - let bytes = (bit_reader.hold() as u16).to_le_bytes(); + let bytes = (self.bit_reader.hold() as u16).to_le_bytes(); self.checksum = crc32(self.checksum, &bytes); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::ExLen; continue 'label; @@ -687,16 +661,16 @@ impl<'a> State<'a> { need_bits!(self, 16); // self.length (and head.extra_len) represent the length of the extra field - self.length = bit_reader.hold() as usize; + self.length = self.bit_reader.hold() as usize; if let Some(head) = self.head.as_mut() { head.extra_len = self.length as u32; } if (self.gzip_flags & 0x0200) != 0 && (self.wrap & 4) != 0 { - let bytes = (bit_reader.hold() as u16).to_le_bytes(); + let bytes = (self.bit_reader.hold() as u16).to_le_bytes(); self.checksum = crc32(self.checksum, &bytes); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); } else if let Some(head) = self.head.as_mut() { head.extra = core::ptr::null_mut(); } @@ -708,7 +682,8 @@ impl<'a> State<'a> { Mode::Extra => { if (self.gzip_flags & 0x0400) != 0 { // self.length is the number of remaining `extra` bytes. But they may not all be available - let extra_available = Ord::min(self.length, bit_reader.bytes_remaining()); + let extra_available = + Ord::min(self.length, self.bit_reader.bytes_remaining()); if extra_available > 0 { if let Some(head) = self.head.as_mut() { @@ -737,7 +712,7 @@ impl<'a> State<'a> { // and bit_reader.bytes_remaining(), so the count won't // go out of bounds. core::ptr::copy_nonoverlapping( - bit_reader.as_mut_ptr(), + self.bit_reader.as_mut_ptr(), head.extra.add(next_write_offset), count, ); @@ -747,12 +722,12 @@ impl<'a> State<'a> { // Checksum if (self.gzip_flags & 0x0200) != 0 && (self.wrap & 4) != 0 { - let extra_slice = &bit_reader.as_slice()[..extra_available]; + let extra_slice = &self.bit_reader.as_slice()[..extra_available]; self.checksum = crc32(self.checksum, extra_slice) } self.in_available -= extra_available; - bit_reader.advance(extra_available); + self.bit_reader.advance(extra_available); self.length -= extra_available; } @@ -775,7 +750,7 @@ impl<'a> State<'a> { // the name string will always be null-terminated, but might be longer than we have // space for in the header struct. Nonetheless, we read the whole thing. - let slice = bit_reader.as_slice(); + let slice = self.bit_reader.as_slice(); let null_terminator_index = slice.iter().position(|c| *c == 0); // we include the null terminator if it exists @@ -811,9 +786,9 @@ impl<'a> State<'a> { } let reached_end = name_slice.last() == Some(&0); - bit_reader.advance(name_slice.len()); + self.bit_reader.advance(name_slice.len()); - if !reached_end && bit_reader.bytes_remaining() == 0 { + if !reached_end && self.bit_reader.bytes_remaining() == 0 { break 'label self.inflate_leave(ReturnCode::Ok); } } else if let Some(head) = self.head.as_mut() { @@ -833,7 +808,7 @@ impl<'a> State<'a> { // the comment string will always be null-terminated, but might be longer than we have // space for in the header struct. Nonetheless, we read the whole thing. - let slice = bit_reader.as_slice(); + let slice = self.bit_reader.as_slice(); let null_terminator_index = slice.iter().position(|c| *c == 0); // we include the null terminator if it exists @@ -869,9 +844,9 @@ impl<'a> State<'a> { } let reached_end = comment_slice.last() == Some(&0); - bit_reader.advance(comment_slice.len()); + self.bit_reader.advance(comment_slice.len()); - if !reached_end && bit_reader.bytes_remaining() == 0 { + if !reached_end && self.bit_reader.bytes_remaining() == 0 { break 'label self.inflate_leave(ReturnCode::Ok); } } else if let Some(head) = self.head.as_mut() { @@ -887,13 +862,13 @@ impl<'a> State<'a> { need_bits!(self, 16); if (self.wrap & 4) != 0 - && bit_reader.hold() as u32 != (self.checksum & 0xffff) + && self.bit_reader.hold() as u32 != (self.checksum & 0xffff) { self.mode = Mode::Bad; break 'label self.bad("header crc mismatch\0"); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); } if let Some(head) = self.head.as_mut() { @@ -925,23 +900,23 @@ impl<'a> State<'a> { } Mode::TypeDo => { if self.flags.contains(Flags::IS_LAST_BLOCK) { - bit_reader.next_byte_boundary(); + self.bit_reader.next_byte_boundary(); self.mode = Mode::Check; continue 'label; } need_bits!(self, 3); - // self.last = bit_reader.bits(1) != 0; + // self.last = self.bit_reader.bits(1) != 0; self.flags - .update(Flags::IS_LAST_BLOCK, bit_reader.bits(1) != 0); - bit_reader.drop_bits(1); + .update(Flags::IS_LAST_BLOCK, self.bit_reader.bits(1) != 0); + self.bit_reader.drop_bits(1); - match bit_reader.bits(2) { + match self.bit_reader.bits(2) { 0b00 => { // eprintln!("inflate: stored block (last = {last})"); - bit_reader.drop_bits(2); + self.bit_reader.drop_bits(2); self.mode = Mode::Stored; @@ -962,7 +937,7 @@ impl<'a> State<'a> { self.mode = Mode::Len_; - bit_reader.drop_bits(2); + self.bit_reader.drop_bits(2); if let InflateFlush::Trees = self.flush { break 'label self.inflate_leave(ReturnCode::Ok); @@ -973,7 +948,7 @@ impl<'a> State<'a> { 0b10 => { // eprintln!("inflate: dynamic codes block (last = {last})"); - bit_reader.drop_bits(2); + self.bit_reader.drop_bits(2); self.mode = Mode::Table; @@ -982,7 +957,7 @@ impl<'a> State<'a> { 0b11 => { // eprintln!("inflate: invalid block type"); - bit_reader.drop_bits(2); + self.bit_reader.drop_bits(2); self.mode = Mode::Bad; break 'label self.bad("invalid block type\0"); @@ -994,11 +969,11 @@ impl<'a> State<'a> { } } Mode::Stored => { - bit_reader.next_byte_boundary(); + self.bit_reader.next_byte_boundary(); need_bits!(self, 32); - let hold = bit_reader.bits(32) as u32; + let hold = self.bit_reader.bits(32) as u32; // eprintln!("hold {hold:#x}"); @@ -1010,7 +985,7 @@ impl<'a> State<'a> { self.length = hold as usize & 0xFFFF; // eprintln!("inflate: stored length {}", state.length); - bit_reader.init_bits(); + self.bit_reader.init_bits(); if let InflateFlush::Trees = self.flush { break 'label self.inflate_leave(ReturnCode::Ok); @@ -1028,15 +1003,15 @@ impl<'a> State<'a> { break; } - copy = Ord::min(copy, writer.remaining()); - copy = Ord::min(copy, bit_reader.bytes_remaining()); + copy = Ord::min(copy, self.writer.remaining()); + copy = Ord::min(copy, self.bit_reader.bytes_remaining()); if copy == 0 { break 'label self.inflate_leave(ReturnCode::Ok); } - writer.extend(&bit_reader.as_slice()[..copy]); - bit_reader.advance(copy); + self.writer.extend(&self.bit_reader.as_slice()[..copy]); + self.bit_reader.advance(copy); self.length -= copy; } @@ -1049,48 +1024,46 @@ impl<'a> State<'a> { if !cfg!(feature = "__internal-fuzz-disable-checksum") && self.wrap != 0 { need_bits!(self, 32); - self.total += writer.len(); + self.total += self.writer.len(); if self.wrap & 4 != 0 { if self.gzip_flags != 0 { - self.crc_fold.fold(writer.filled(), self.checksum); + self.crc_fold.fold(self.writer.filled(), self.checksum); self.checksum = self.crc_fold.finish(); } else { - self.checksum = adler32(self.checksum, writer.filled()); + self.checksum = adler32(self.checksum, self.writer.filled()); } } let given_checksum = if self.gzip_flags != 0 { - bit_reader.hold() as u32 + self.bit_reader.hold() as u32 } else { - zswap32(bit_reader.hold() as u32) + zswap32(self.bit_reader.hold() as u32) }; - self.out_available = writer.capacity() - writer.len(); + self.out_available = self.writer.capacity() - self.writer.len(); if self.wrap & 4 != 0 && given_checksum != self.checksum { self.mode = Mode::Bad; break 'label self.bad("incorrect data check\0"); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); } self.mode = Mode::Length; continue 'label; } Mode::Len => { - let avail_in = bit_reader.bytes_remaining(); - let avail_out = writer.remaining(); + let avail_in = self.bit_reader.bytes_remaining(); + let avail_out = self.writer.remaining(); // INFLATE_FAST_MIN_LEFT is important. It makes sure there is at least 32 bytes of free // space available. This means for many SIMD operations we don't need to process a // remainder; we just copy blindly, and a later operation will overwrite the extra copied // bytes if avail_in >= INFLATE_FAST_MIN_HAVE && avail_out >= INFLATE_FAST_MIN_LEFT { - restore!(self); inflate_fast_help(self, 0); - load!(self); continue 'label; } @@ -1099,10 +1072,10 @@ impl<'a> State<'a> { // get a literal, length, or end-of-block code let mut here; loop { - let bits = bit_reader.bits(self.len_table.bits); + let bits = self.bit_reader.bits(self.len_table.bits); here = self.len_table_get(bits as usize); - if here.bits <= bit_reader.bits_in_buffer() { + if here.bits <= self.bit_reader.bits_in_buffer() { break; } @@ -1112,20 +1085,20 @@ impl<'a> State<'a> { if here.op != 0 && here.op & 0xf0 == 0 { let last = here; loop { - let bits = bit_reader.bits((last.bits + last.op) as usize) as u16; + let bits = self.bit_reader.bits((last.bits + last.op) as usize) as u16; here = self.len_table_get((last.val + (bits >> last.bits)) as usize); - if last.bits + here.bits <= bit_reader.bits_in_buffer() { + if last.bits + here.bits <= self.bit_reader.bits_in_buffer() { break; } pull_byte!(self); } - bit_reader.drop_bits(last.bits); + self.bit_reader.drop_bits(last.bits); self.back += last.bits as usize; } - bit_reader.drop_bits(here.bits); + self.bit_reader.drop_bits(here.bits); self.back += here.bits as usize; self.length = here.val as usize; @@ -1165,8 +1138,8 @@ impl<'a> State<'a> { // get extra bits, if any if extra != 0 { need_bits!(self, extra); - self.length += bit_reader.bits(extra) as usize; - bit_reader.drop_bits(extra as u8); + self.length += self.bit_reader.bits(extra) as usize; + self.bit_reader.drop_bits(extra as u8); self.back += extra; } @@ -1178,13 +1151,13 @@ impl<'a> State<'a> { continue 'label; } Mode::Lit => { - if writer.is_full() { + if self.writer.is_full() { #[cfg(all(test, feature = "std"))] - eprintln!("Ok: writer is full ({} bytes)", writer.capacity()); + eprintln!("Ok: writer is full ({} bytes)", self.writer.capacity()); break 'label self.inflate_leave(ReturnCode::Ok); } - writer.push(self.length as u8); + self.writer.push(self.length as u8); self.mode = Mode::Len; @@ -1194,9 +1167,9 @@ impl<'a> State<'a> { // get distance code let mut here; loop { - let bits = bit_reader.bits(self.dist_table.bits) as usize; + let bits = self.bit_reader.bits(self.dist_table.bits) as usize; here = self.dist_table_get(bits); - if here.bits <= bit_reader.bits_in_buffer() { + if here.bits <= self.bit_reader.bits_in_buffer() { break; } @@ -1207,22 +1180,22 @@ impl<'a> State<'a> { let last = here; loop { - let bits = bit_reader.bits((last.bits + last.op) as usize); + let bits = self.bit_reader.bits((last.bits + last.op) as usize); here = self .dist_table_get(last.val as usize + ((bits as usize) >> last.bits)); - if last.bits + here.bits <= bit_reader.bits_in_buffer() { + if last.bits + here.bits <= self.bit_reader.bits_in_buffer() { break; } pull_byte!(self); } - bit_reader.drop_bits(last.bits); + self.bit_reader.drop_bits(last.bits); self.back += last.bits as usize; } - bit_reader.drop_bits(here.bits); + self.bit_reader.drop_bits(here.bits); if here.op & 64 != 0 { self.mode = Mode::Bad; @@ -1241,8 +1214,8 @@ impl<'a> State<'a> { if extra > 0 { need_bits!(self, extra); - self.offset += bit_reader.bits(extra) as usize; - bit_reader.drop_bits(extra as u8); + self.offset += self.bit_reader.bits(extra) as usize; + self.bit_reader.drop_bits(extra as u8); self.back += extra; } @@ -1259,14 +1232,17 @@ impl<'a> State<'a> { } Mode::Match => { 'match_: loop { - if writer.is_full() { + if self.writer.is_full() { #[cfg(all(feature = "std", test))] - eprintln!("BufError: writer is full ({} bytes)", writer.capacity()); + eprintln!( + "BufError: writer is full ({} bytes)", + self.writer.capacity() + ); break 'label self.inflate_leave(ReturnCode::Ok); } - let left = writer.remaining(); - let copy = writer.len(); + let left = self.writer.remaining(); + let copy = self.writer.len(); let copy = if self.offset > copy { // copy from window to output @@ -1296,12 +1272,13 @@ impl<'a> State<'a> { copy = Ord::min(copy, self.length); copy = Ord::min(copy, left); - writer.extend_from_window(&self.window, from..from + copy); + self.writer + .extend_from_window(&self.window, from..from + copy); copy } else { let copy = Ord::min(self.length, left); - writer.copy_match(self.offset, copy); + self.writer.copy_match(self.offset, copy); copy }; @@ -1321,12 +1298,12 @@ impl<'a> State<'a> { Mode::Done => todo!(), Mode::Table => { need_bits!(self, 14); - self.nlen = bit_reader.bits(5) as usize + 257; - bit_reader.drop_bits(5); - self.ndist = bit_reader.bits(5) as usize + 1; - bit_reader.drop_bits(5); - self.ncode = bit_reader.bits(4) as usize + 4; - bit_reader.drop_bits(4); + self.nlen = self.bit_reader.bits(5) as usize + 257; + self.bit_reader.drop_bits(5); + self.ndist = self.bit_reader.bits(5) as usize + 1; + self.bit_reader.drop_bits(5); + self.ncode = self.bit_reader.bits(4) as usize + 4; + self.bit_reader.drop_bits(4); // TODO pkzit_bug_workaround if self.nlen > 286 || self.ndist > 30 { @@ -1347,9 +1324,9 @@ impl<'a> State<'a> { while self.have < self.ncode { need_bits!(self, 3); - self.lens[ORDER[self.have] as usize] = bit_reader.bits(3) as u16; + self.lens[ORDER[self.have] as usize] = self.bit_reader.bits(3) as u16; self.have += 1; - bit_reader.drop_bits(3); + self.bit_reader.drop_bits(3); } while self.have < 19 { @@ -1382,9 +1359,9 @@ impl<'a> State<'a> { Mode::CodeLens => { while self.have < self.nlen + self.ndist { let here = loop { - let bits = bit_reader.bits(self.len_table.bits); + let bits = self.bit_reader.bits(self.len_table.bits); let here = self.len_table_get(bits as usize); - if here.bits <= bit_reader.bits_in_buffer() { + if here.bits <= self.bit_reader.bits_in_buffer() { break here; } @@ -1395,21 +1372,21 @@ impl<'a> State<'a> { match here.val { 0..=15 => { - bit_reader.drop_bits(here_bits); + self.bit_reader.drop_bits(here_bits); self.lens[self.have] = here.val; self.have += 1; } 16 => { need_bits!(self, here_bits as usize + 2); - bit_reader.drop_bits(here_bits); + self.bit_reader.drop_bits(here_bits); if self.have == 0 { self.mode = Mode::Bad; break 'label self.bad("invalid bit length repeat\0"); } let len = self.lens[self.have - 1]; - let copy = 3 + bit_reader.bits(2) as usize; - bit_reader.drop_bits(2); + let copy = 3 + self.bit_reader.bits(2) as usize; + self.bit_reader.drop_bits(2); if self.have + copy > self.nlen + self.ndist { self.mode = Mode::Bad; @@ -1423,10 +1400,10 @@ impl<'a> State<'a> { } 17 => { need_bits!(self, here_bits as usize + 3); - bit_reader.drop_bits(here_bits); + self.bit_reader.drop_bits(here_bits); let len = 0; - let copy = 3 + bit_reader.bits(3) as usize; - bit_reader.drop_bits(3); + let copy = 3 + self.bit_reader.bits(3) as usize; + self.bit_reader.drop_bits(3); if self.have + copy > self.nlen + self.ndist { self.mode = Mode::Bad; @@ -1440,10 +1417,10 @@ impl<'a> State<'a> { } 18.. => { need_bits!(self, here_bits as usize + 7); - bit_reader.drop_bits(here_bits); + self.bit_reader.drop_bits(here_bits); let len = 0; - let copy = 11 + bit_reader.bits(7) as usize; - bit_reader.drop_bits(7); + let copy = 11 + self.bit_reader.bits(7) as usize; + self.bit_reader.drop_bits(7); if self.have + copy > self.nlen + self.ndist { self.mode = Mode::Bad; @@ -1522,9 +1499,9 @@ impl<'a> State<'a> { Mode::DictId => { need_bits!(self, 32); - self.checksum = zswap32(bit_reader.hold() as u32); + self.checksum = zswap32(self.bit_reader.hold() as u32); - bit_reader.init_bits(); + self.bit_reader.init_bits(); self.mode = Mode::Dict; @@ -1548,23 +1525,19 @@ impl<'a> State<'a> { // for gzip, last bytes contain LENGTH if self.wrap != 0 && self.gzip_flags != 0 { need_bits!(self, 32); - if (self.wrap & 4) != 0 && bit_reader.hold() != self.total as u64 { + if (self.wrap & 4) != 0 && self.bit_reader.hold() != self.total as u64 { self.mode = Mode::Bad; break 'label self.bad("incorrect length check\0"); } - bit_reader.init_bits(); + self.bit_reader.init_bits(); } // inflate stream terminated properly break 'label ReturnCode::StreamEnd; } }; - }; - - restore!(self); - - return_code + } } fn bad(&mut self, msg: &'static str) -> ReturnCode {