From 1a1d6cc0cfe39afd0efe50cb35f4249285c8ac38 Mon Sep 17 00:00:00 2001 From: Jakob Degen Date: Tue, 8 Feb 2022 16:13:22 -0500 Subject: [PATCH] Rewrite the multi cartesian product iterator to both simplify it and fix a bug. --- src/adaptors/multi_product.rs | 295 ++++++++++++++++------------------ tests/quick.rs | 6 + 2 files changed, 143 insertions(+), 158 deletions(-) diff --git a/src/adaptors/multi_product.rs b/src/adaptors/multi_product.rs index 30650eda6..70a5e29e1 100644 --- a/src/adaptors/multi_product.rs +++ b/src/adaptors/multi_product.rs @@ -1,7 +1,6 @@ #![cfg(feature = "use_alloc")] use crate::size_hint; -use crate::Itertools; use alloc::vec::Vec; @@ -14,217 +13,197 @@ use alloc::vec::Vec; /// See [`.multi_cartesian_product()`](crate::Itertools::multi_cartesian_product) /// for more information. #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] -pub struct MultiProduct(Vec>) - where I: Iterator + Clone, - I::Item: Clone; +pub struct MultiProduct +where + I: Iterator + Clone, + I::Item: Clone, +{ + // The last thing we returned + state: MultiProductState, + iters: Vec>, +} impl std::fmt::Debug for MultiProduct where I: Iterator + Clone + std::fmt::Debug, I::Item: Clone + std::fmt::Debug, { - debug_fmt_fields!(CoalesceBy, 0); + debug_fmt_fields!(CoalesceBy, iters); +} + +/// Stores the current state of the iterator. +#[derive(Clone)] +enum MultiProductState { + /// In the middle of an iteration. The `Vec` is the last value we returned + InProgress(Vec), + /// At the beginning of an iteration. The `Vec` is the next value to be returned. + Restarted(Vec), + /// Iteration has not been started + Unstarted, } +use MultiProductState::*; /// Create a new cartesian product iterator over an arbitrary number /// of iterators of the same type. /// /// Iterator element is of type `Vec`. pub fn multi_cartesian_product(iters: H) -> MultiProduct<::IntoIter> - where H: Iterator, - H::Item: IntoIterator, - ::IntoIter: Clone, - ::Item: Clone +where + H: Iterator, + H::Item: IntoIterator, + ::IntoIter: Clone, + ::Item: Clone, { - MultiProduct(iters.map(|i| MultiProductIter::new(i.into_iter())).collect()) + MultiProduct { + state: MultiProductState::Unstarted, + iters: iters + .map(|i| MultiProductIter::new(i.into_iter())) + .collect(), + } } #[derive(Clone, Debug)] /// Holds the state of a single iterator within a MultiProduct. struct MultiProductIter - where I: Iterator + Clone, - I::Item: Clone +where + I: Iterator + Clone, + I::Item: Clone, { - cur: Option, iter: I, iter_orig: I, } -/// Holds the current state during an iteration of a MultiProduct. -#[derive(Debug)] -enum MultiProductIterState { - StartOfIter, - MidIter { on_first_iter: bool }, -} - -impl MultiProduct - where I: Iterator + Clone, - I::Item: Clone +impl MultiProductIter +where + I: Iterator + Clone, + I::Item: Clone, { - /// Iterates the rightmost iterator, then recursively iterates iterators - /// to the left if necessary. - /// - /// Returns true if the iteration succeeded, else false. - fn iterate_last( - multi_iters: &mut [MultiProductIter], - mut state: MultiProductIterState - ) -> bool { - use self::MultiProductIterState::*; - - if let Some((last, rest)) = multi_iters.split_last_mut() { - let on_first_iter = match state { - StartOfIter => { - let on_first_iter = !last.in_progress(); - state = MidIter { on_first_iter }; - on_first_iter - }, - MidIter { on_first_iter } => on_first_iter - }; - - if !on_first_iter { - last.iterate(); - } - - if last.in_progress() { - true - } else if MultiProduct::iterate_last(rest, state) { - last.reset(); - last.iterate(); - // If iterator is None twice consecutively, then iterator is - // empty; whole product is empty. - last.in_progress() - } else { - false - } - } else { - // Reached end of iterator list. On initialisation, return true. - // At end of iteration (final iterator finishes), finish. - match state { - StartOfIter => false, - MidIter { on_first_iter } => on_first_iter - } - } - } - - /// Returns the unwrapped value of the next iteration. - fn curr_iterator(&self) -> Vec { - self.0.iter().map(|multi_iter| { - multi_iter.cur.clone().unwrap() - }).collect() - } - - /// Returns true if iteration has started and has not yet finished; false - /// otherwise. - fn in_progress(&self) -> bool { - if let Some(last) = self.0.last() { - last.in_progress() - } else { - false - } + fn reset(&mut self) { + self.iter = self.iter_orig.clone(); } -} -impl MultiProductIter - where I: Iterator + Clone, - I::Item: Clone -{ fn new(iter: I) -> Self { MultiProductIter { - cur: None, iter: iter.clone(), - iter_orig: iter + iter_orig: iter, } } - /// Iterate the managed iterator. - fn iterate(&mut self) { - self.cur = self.iter.next(); - } - - /// Reset the managed iterator. - fn reset(&mut self) { - self.iter = self.iter_orig.clone(); - } - - /// Returns true if the current iterator has been started and has not yet - /// finished; false otherwise. - fn in_progress(&self) -> bool { - self.cur.is_some() + fn next(&mut self) -> Option { + self.iter.next() } } impl Iterator for MultiProduct - where I: Iterator + Clone, - I::Item: Clone +where + I: Iterator + Clone, + I::Item: Clone, { type Item = Vec; fn next(&mut self) -> Option { - if MultiProduct::iterate_last( - &mut self.0, - MultiProductIterState::StartOfIter - ) { - Some(self.curr_iterator()) - } else { - None + let last = match &mut self.state { + InProgress(v) => v, + Restarted(v) => { + let v = core::mem::take(v); + self.state = InProgress(v.clone()); + return Some(v); + } + Unstarted => { + let next: Option> = self.iters.iter_mut().map(|i| i.next()).collect(); + if let Some(v) = &next { + self.state = InProgress(v.clone()); + } + return next; + } + }; + + // Starting from the last iterator, advance each iterator until we find one that returns a + // value. + for i in (0..self.iters.len()).rev() { + let iter = &mut self.iters[i]; + let loc = &mut last[i]; + if let Some(val) = iter.next() { + *loc = val; + return Some(last.clone()); + } else { + iter.reset(); + if let Some(val) = iter.next() { + *loc = val; + } else { + // This case should not really take place; we had an in progress iterator, reset + // it, and called `.next()`, but now its empty. In any case, the product is + // empty now and we should handle things accordingly. + self.state = Unstarted; + return None; + } + } } + + // Reaching here indicates that all the iterators returned none, and so iteration has completed + let v = core::mem::take(last); + self.state = Restarted(v); + None } fn count(self) -> usize { - if self.0.is_empty() { - return 0; - } - - if !self.in_progress() { - return self.0.into_iter().fold(1, |acc, multi_iter| { - acc * multi_iter.iter.count() - }); + // `remaining` is the number of remaining iterations before the current iterator is + // exhausted. `per_reset` is the number of total iterations that take place each time the + // current iterator is reset + let (remaining, per_reset) = + self.iters + .into_iter() + .rev() + .fold((0, 1), |(remaining, per_reset), iter| { + let remaining = remaining + per_reset * iter.iter.count(); + let per_reset = per_reset * iter.iter_orig.count(); + (remaining, per_reset) + }); + if let Restarted(_) | Unstarted = &self.state { + per_reset + } else { + remaining } - - self.0.into_iter().fold( - 0, - |acc, MultiProductIter { iter, iter_orig, cur: _ }| { - let total_count = iter_orig.count(); - let cur_count = iter.count(); - acc * total_count + cur_count - } - ) } fn size_hint(&self) -> (usize, Option) { - // Not ExactSizeIterator because size may be larger than usize - if self.0.is_empty() { - return (0, Some(0)); - } - - if !self.in_progress() { - return self.0.iter().fold((1, Some(1)), |acc, multi_iter| { - size_hint::mul(acc, multi_iter.iter.size_hint()) - }); + let initial = ((0, Some(0)), (1, Some(1))); + // Exact same logic as for `count` + let (remaining, per_reset) = + self.iters + .iter() + .rev() + .fold(initial, |(remaining, per_reset), iter| { + let prod = size_hint::mul(per_reset, iter.iter.size_hint()); + let remaining = size_hint::add(remaining, prod); + let per_reset = size_hint::mul(per_reset, iter.iter_orig.size_hint()); + (remaining, per_reset) + }); + if let Restarted(_) | Unstarted = &self.state { + per_reset + } else { + remaining } - - self.0.iter().fold( - (0, Some(0)), - |acc, &MultiProductIter { ref iter, ref iter_orig, cur: _ }| { - let cur_size = iter.size_hint(); - let total_size = iter_orig.size_hint(); - size_hint::add(size_hint::mul(acc, total_size), cur_size) - } - ) } fn last(self) -> Option { - let iter_count = self.0.len(); - - let lasts: Self::Item = self.0.into_iter() - .map(|multi_iter| multi_iter.iter.last()) - .while_some() - .collect(); - - if lasts.len() == iter_count { - Some(lasts) + // The way resetting works makes the first iterator a little bit special + let mut iter = self.iters.into_iter(); + if let Some(first) = iter.next() { + let first = if let Restarted(_) | Unstarted = &self.state { + first.iter_orig.last() + } else { + first.iter.last() + }; + core::iter::once(first) + .chain(iter.map(|sub| sub.iter_orig.last())) + .collect() } else { - None + if let Restarted(_) | Unstarted = &self.state { + Some(Vec::new()) + } else { + None + } } } } diff --git a/tests/quick.rs b/tests/quick.rs index 7e222a641..cf93c27b3 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -441,6 +441,12 @@ quickcheck! { assert_eq!(answer.into_iter().last(), a.clone().multi_cartesian_product().last()); } + fn correct_empty_multi_product() -> () { + let mut empty = Vec::>::new().into_iter().multi_cartesian_product(); + assert!(correct_size_hint(empty.clone())); + assert_eq!(empty.next(), Some(Vec::new())) + } + #[allow(deprecated)] fn size_step(a: Iter, s: usize) -> bool { let mut s = s;