From 2e2a0196661b77fdb2ff58c8e606aecaf7a08f40 Mon Sep 17 00:00:00 2001 From: Jakob Degen Date: Wed, 9 Feb 2022 19:53:04 -0500 Subject: [PATCH] Rewrite the multi cartesian product iterator to both simplify it and fix a bug. Rebased, fix a `;` that should have been `,`. Fix the test `correct_empty_multi_product` to ensure that is returns None after the first and only empty vector. --- src/adaptors/multi_product.rs | 273 ++++++++++++++-------------------- tests/quick.rs | 6 + 2 files changed, 119 insertions(+), 160 deletions(-) diff --git a/src/adaptors/multi_product.rs b/src/adaptors/multi_product.rs index ef7fadba8..a087d0f8b 100644 --- a/src/adaptors/multi_product.rs +++ b/src/adaptors/multi_product.rs @@ -1,7 +1,6 @@ #![cfg(feature = "use_alloc")] use crate::size_hint; -use crate::Itertools; use alloc::vec::Vec; @@ -14,18 +13,34 @@ use alloc::vec::Vec; /// See [`.multi_cartesian_product()`](crate::Itertools::multi_cartesian_product) /// for more information. #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] -pub struct MultiProduct(Vec>) +pub struct MultiProduct where I: Iterator + Clone, - I::Item: Clone; + I::Item: Clone, +{ + state: MultiProductState, + iters: Vec>, +} impl std::fmt::Debug for MultiProduct where I: Iterator + Clone + std::fmt::Debug, I::Item: Clone + std::fmt::Debug, { - debug_fmt_fields!(CoalesceBy, 0); + debug_fmt_fields!(CoalesceBy, iters); +} + +/// Stores the current state of the iterator. +#[derive(Clone)] +enum MultiProductState { + /// In the middle of an iteration. The `Vec` is the last value we returned + InProgress(Vec), + /// At the beginning of an iteration. The `Vec` is the next value to be returned. + Restarted(Vec), + /// Iteration has not been started + Unstarted, } +use MultiProductState::*; /// Create a new cartesian product iterator over an arbitrary number /// of iterators of the same type. @@ -38,11 +53,12 @@ where ::IntoIter: Clone, ::Item: Clone, { - MultiProduct( - iters + MultiProduct { + state: MultiProductState::Unstarted, + iters: iters .map(|i| MultiProductIter::new(i.into_iter())) .collect(), - ) + } } #[derive(Clone, Debug)] @@ -52,87 +68,10 @@ where I: Iterator + Clone, I::Item: Clone, { - cur: Option, iter: I, iter_orig: I, } -/// Holds the current state during an iteration of a `MultiProduct`. -#[derive(Debug)] -enum MultiProductIterState { - StartOfIter, - MidIter { on_first_iter: bool }, -} - -impl MultiProduct -where - I: Iterator + Clone, - I::Item: Clone, -{ - /// Iterates the rightmost iterator, then recursively iterates iterators - /// to the left if necessary. - /// - /// Returns true if the iteration succeeded, else false. - fn iterate_last( - multi_iters: &mut [MultiProductIter], - mut state: MultiProductIterState, - ) -> bool { - use self::MultiProductIterState::*; - - if let Some((last, rest)) = multi_iters.split_last_mut() { - let on_first_iter = match state { - StartOfIter => { - let on_first_iter = !last.in_progress(); - state = MidIter { on_first_iter }; - on_first_iter - } - MidIter { on_first_iter } => on_first_iter, - }; - - if !on_first_iter { - last.iterate(); - } - - if last.in_progress() { - true - } else if MultiProduct::iterate_last(rest, state) { - last.reset(); - last.iterate(); - // If iterator is None twice consecutively, then iterator is - // empty; whole product is empty. - last.in_progress() - } else { - false - } - } else { - // Reached end of iterator list. On initialisation, return true. - // At end of iteration (final iterator finishes), finish. - match state { - StartOfIter => false, - MidIter { on_first_iter } => on_first_iter, - } - } - } - - /// Returns the unwrapped value of the next iteration. - fn curr_iterator(&self) -> Vec { - self.0 - .iter() - .map(|multi_iter| multi_iter.cur.clone().unwrap()) - .collect() - } - - /// Returns true if iteration has started and has not yet finished; false - /// otherwise. - fn in_progress(&self) -> bool { - if let Some(last) = self.0.last() { - last.in_progress() - } else { - false - } - } -} - impl MultiProductIter where I: Iterator + Clone, @@ -140,26 +79,13 @@ where { fn new(iter: I) -> Self { MultiProductIter { - cur: None, iter: iter.clone(), iter_orig: iter, } } - /// Iterate the managed iterator. - fn iterate(&mut self) { - self.cur = self.iter.next(); - } - - /// Reset the managed iterator. - fn reset(&mut self) { - self.iter = self.iter_orig.clone(); - } - - /// Returns true if the current iterator has been started and has not yet - /// finished; false otherwise. - fn in_progress(&self) -> bool { - self.cur.is_some() + fn next(&mut self) -> Option { + self.iter.next() } } @@ -171,81 +97,108 @@ where type Item = Vec; fn next(&mut self) -> Option { - if MultiProduct::iterate_last(&mut self.0, MultiProductIterState::StartOfIter) { - Some(self.curr_iterator()) - } else { - None + let last = match &mut self.state { + InProgress(v) => v, + Restarted(v) => { + let v = core::mem::replace(v, Vec::new()); + self.state = InProgress(v.clone()); + return Some(v); + } + Unstarted => { + let next: Option> = self.iters.iter_mut().map(|i| i.next()).collect(); + if let Some(v) = &next { + self.state = InProgress(v.clone()); + } + return next; + } + }; + + // Starting from the last iterator, advance each iterator until we find one that returns a + // value. + for i in (0..self.iters.len()).rev() { + let iter = &mut self.iters[i]; + let loc = &mut last[i]; + if let Some(val) = iter.next() { + *loc = val; + return Some(last.clone()); + } else { + iter.iter = iter.iter_orig.clone(); + if let Some(val) = iter.next() { + *loc = val; + } else { + // This case should not really take place; we had an in progress iterator, reset + // it, and called `.next()`, but now its empty. In any case, the product is + // empty now and we should handle things accordingly. + self.state = Unstarted; + return None; + } + } } + + // Reaching here indicates that all the iterators returned none, and so iteration has completed + let v = core::mem::replace(last, Vec::new()); + self.state = Restarted(v); + None } fn count(self) -> usize { - if self.0.is_empty() { - return 0; - } - - if !self.in_progress() { - return self - .0 + // `remaining` is the number of remaining iterations before the current iterator is + // exhausted. `per_reset` is the number of total iterations that take place each time the + // current iterator is reset + let (remaining, per_reset) = + self.iters .into_iter() - .fold(1, |acc, multi_iter| acc * multi_iter.iter.count()); + .rev() + .fold((0, 1), |(remaining, per_reset), iter| { + let remaining = remaining + per_reset * iter.iter.count(); + let per_reset = per_reset * iter.iter_orig.count(); + (remaining, per_reset) + }); + if let Restarted(_) | Unstarted = &self.state { + per_reset + } else { + remaining } - - self.0.into_iter().fold( - 0, - |acc, - MultiProductIter { - iter, - iter_orig, - cur: _, - }| { - let total_count = iter_orig.count(); - let cur_count = iter.count(); - acc * total_count + cur_count - }, - ) } fn size_hint(&self) -> (usize, Option) { - // Not ExactSizeIterator because size may be larger than usize - if self.0.is_empty() { - return (0, Some(0)); - } - - if !self.in_progress() { - return self.0.iter().fold((1, Some(1)), |acc, multi_iter| { - size_hint::mul(acc, multi_iter.iter.size_hint()) - }); + let initial = ((0, Some(0)), (1, Some(1))); + // Exact same logic as for `count` + let (remaining, per_reset) = + self.iters + .iter() + .rev() + .fold(initial, |(remaining, per_reset), iter| { + let prod = size_hint::mul(per_reset, iter.iter.size_hint()); + let remaining = size_hint::add(remaining, prod); + let per_reset = size_hint::mul(per_reset, iter.iter_orig.size_hint()); + (remaining, per_reset) + }); + if let Restarted(_) | Unstarted = &self.state { + per_reset + } else { + remaining } - - self.0.iter().fold( - (0, Some(0)), - |acc, - &MultiProductIter { - ref iter, - ref iter_orig, - cur: _, - }| { - let cur_size = iter.size_hint(); - let total_size = iter_orig.size_hint(); - size_hint::add(size_hint::mul(acc, total_size), cur_size) - }, - ) } fn last(self) -> Option { - let iter_count = self.0.len(); - - let lasts: Self::Item = self - .0 - .into_iter() - .map(|multi_iter| multi_iter.iter.last()) - .while_some() - .collect(); - - if lasts.len() == iter_count { - Some(lasts) + // The way resetting works makes the first iterator a little bit special + let mut iter = self.iters.into_iter(); + if let Some(first) = iter.next() { + let first = if let Restarted(_) | Unstarted = &self.state { + first.iter_orig.last() + } else { + first.iter.last() + }; + core::iter::once(first) + .chain(iter.map(|sub| sub.iter_orig.last())) + .collect() } else { - None + if let Restarted(_) | Unstarted = &self.state { + Some(Vec::new()) + } else { + None + } } } } diff --git a/tests/quick.rs b/tests/quick.rs index 92d3f9f8e..b0da6bfa6 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -450,6 +450,12 @@ quickcheck! { assert_eq!(answer.into_iter().last(), a.multi_cartesian_product().last()); } + fn correct_empty_multi_product() -> () { + let empty = Vec::>::new().into_iter().multi_cartesian_product(); + assert!(correct_size_hint(empty.clone())); + itertools::assert_equal(empty, std::iter::once(Vec::new())) + } + #[allow(deprecated)] fn size_step(a: Iter, s: usize) -> bool { let mut s = s;