diff --git a/src/adaptors/multi_product.rs b/src/adaptors/multi_product.rs index 74cb06d0e..d551c0afc 100644 --- a/src/adaptors/multi_product.rs +++ b/src/adaptors/multi_product.rs @@ -1,30 +1,52 @@ #![cfg(feature = "use_alloc")] - -use crate::size_hint; -use crate::Itertools; +use Option::{self as State, None as ProductEnded, Some as ProductInProgress}; +use Option::{self as CurrentItems, None as NotYetPopulated, Some as Populated}; use alloc::vec::Vec; +use crate::size_hint; + #[derive(Clone)] /// An iterator adaptor that iterates over the cartesian product of /// multiple iterators of type `I`. /// -/// An iterator element type is `Vec`. +/// An iterator element type is `Vec`. /// /// See [`.multi_cartesian_product()`](crate::Itertools::multi_cartesian_product) /// for more information. #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] -pub struct MultiProduct(Vec>) +pub struct MultiProduct(State>) where I: Iterator + Clone, I::Item: Clone; +#[derive(Clone)] +/// Internals for `MultiProduct`. +struct MultiProductInner +where + I: Iterator + Clone, + I::Item: Clone, +{ + /// Holds the iterators. + iters: Vec>, + /// Not populated at the beginning then it holds the current item of each iterator. + cur: CurrentItems>, +} + impl std::fmt::Debug for MultiProduct where I: Iterator + Clone + std::fmt::Debug, I::Item: Clone + std::fmt::Debug, { - debug_fmt_fields!(CoalesceBy, 0); + debug_fmt_fields!(MultiProduct, 0); +} + +impl std::fmt::Debug for MultiProductInner +where + I: Iterator + Clone + std::fmt::Debug, + I::Item: Clone + std::fmt::Debug, +{ + debug_fmt_fields!(MultiProductInner, iters, cur); } /// Create a new cartesian product iterator over an arbitrary number @@ -38,11 +60,13 @@ where ::IntoIter: Clone, ::Item: Clone, { - MultiProduct( - iters + let inner = MultiProductInner { + iters: iters .map(|i| MultiProductIter::new(i.into_iter())) .collect(), - ) + cur: NotYetPopulated, + }; + MultiProduct(ProductInProgress(inner)) } #[derive(Clone, Debug)] @@ -52,87 +76,10 @@ where I: Iterator + Clone, I::Item: Clone, { - cur: Option, iter: I, iter_orig: I, } -/// Holds the current state during an iteration of a `MultiProduct`. -#[derive(Debug)] -enum MultiProductIterState { - StartOfIter, - MidIter { on_first_iter: bool }, -} - -impl MultiProduct -where - I: Iterator + Clone, - I::Item: Clone, -{ - /// Iterates the rightmost iterator, then recursively iterates iterators - /// to the left if necessary. - /// - /// Returns true if the iteration succeeded, else false. - fn iterate_last( - multi_iters: &mut [MultiProductIter], - mut state: MultiProductIterState, - ) -> bool { - use self::MultiProductIterState::*; - - if let Some((last, rest)) = multi_iters.split_last_mut() { - let on_first_iter = match state { - StartOfIter => { - let on_first_iter = !last.in_progress(); - state = MidIter { on_first_iter }; - on_first_iter - } - MidIter { on_first_iter } => on_first_iter, - }; - - if !on_first_iter { - last.iterate(); - } - - if last.in_progress() { - true - } else if Self::iterate_last(rest, state) { - last.reset(); - last.iterate(); - // If iterator is None twice consecutively, then iterator is - // empty; whole product is empty. - last.in_progress() - } else { - false - } - } else { - // Reached end of iterator list. On initialisation, return true. - // At end of iteration (final iterator finishes), finish. - match state { - StartOfIter => false, - MidIter { on_first_iter } => on_first_iter, - } - } - } - - /// Returns the unwrapped value of the next iteration. - fn curr_iterator(&self) -> Vec { - self.0 - .iter() - .map(|multi_iter| multi_iter.cur.clone().unwrap()) - .collect() - } - - /// Returns true if iteration has started and has not yet finished; false - /// otherwise. - fn in_progress(&self) -> bool { - if let Some(last) = self.0.last() { - last.in_progress() - } else { - false - } - } -} - impl MultiProductIter where I: Iterator + Clone, @@ -140,27 +87,10 @@ where { fn new(iter: I) -> Self { Self { - cur: None, iter: iter.clone(), iter_orig: iter, } } - - /// Iterate the managed iterator. - fn iterate(&mut self) { - self.cur = self.iter.next(); - } - - /// Reset the managed iterator. - fn reset(&mut self) { - self.iter = self.iter_orig.clone(); - } - - /// Returns true if the current iterator has been started and has not yet - /// finished; false otherwise. - fn in_progress(&self) -> bool { - self.cur.is_some() - } } impl Iterator for MultiProduct @@ -171,81 +101,131 @@ where type Item = Vec; fn next(&mut self) -> Option { - if Self::iterate_last(&mut self.0, MultiProductIterState::StartOfIter) { - Some(self.curr_iterator()) - } else { - None + // This fuses the iterator. + let inner = self.0.as_mut()?; + match &mut inner.cur { + Populated(values) => { + debug_assert!(!inner.iters.is_empty()); + // Find (from the right) a non-finished iterator and + // reset the finished ones encountered. + for (iter, item) in inner.iters.iter_mut().zip(values.iter_mut()).rev() { + if let Some(new) = iter.iter.next() { + *item = new; + return Some(values.clone()); + } else { + iter.iter = iter.iter_orig.clone(); + // `cur` is populated so the untouched `iter_orig` can not be empty. + *item = iter.iter.next().unwrap(); + } + } + self.0 = ProductEnded; + None + } + // Only the first time. + NotYetPopulated => { + let next: Option> = inner.iters.iter_mut().map(|i| i.iter.next()).collect(); + if next.is_none() || inner.iters.is_empty() { + // This cartesian product had at most one item to generate and now ends. + self.0 = ProductEnded; + } else { + inner.cur = next.clone(); + } + next + } } } fn count(self) -> usize { - if self.0.is_empty() { - return 0; - } - - if !self.in_progress() { - return self - .0 + match self.0 { + ProductEnded => 0, + // The iterator is fresh so the count is the product of the length of each iterator: + // - If one of them is empty, stop counting. + // - Less `count()` calls than the general case. + ProductInProgress(MultiProductInner { + iters, + cur: NotYetPopulated, + }) => iters .into_iter() - .fold(1, |acc, multi_iter| acc * multi_iter.iter.count()); + .map(|iter| iter.iter_orig.count()) + .try_fold(1, |product, count| { + if count == 0 { + None + } else { + Some(product * count) + } + }) + .unwrap_or_default(), + // The general case. + ProductInProgress(MultiProductInner { + iters, + cur: Populated(_), + }) => iters.into_iter().fold(0, |mut acc, iter| { + if acc != 0 { + acc *= iter.iter_orig.count(); + } + acc + iter.iter.count() + }), } - - self.0.into_iter().fold( - 0, - |acc, - MultiProductIter { - iter, - iter_orig, - cur: _, - }| { - let total_count = iter_orig.count(); - let cur_count = iter.count(); - acc * total_count + cur_count - }, - ) } fn size_hint(&self) -> (usize, Option) { - // Not ExactSizeIterator because size may be larger than usize - if self.0.is_empty() { - return (0, Some(0)); - } - - if !self.in_progress() { - return self.0.iter().fold((1, Some(1)), |acc, multi_iter| { - size_hint::mul(acc, multi_iter.iter.size_hint()) - }); + match &self.0 { + ProductEnded => (0, Some(0)), + ProductInProgress(MultiProductInner { + iters, + cur: NotYetPopulated, + }) => iters + .iter() + .map(|iter| iter.iter_orig.size_hint()) + .fold((1, Some(1)), size_hint::mul), + ProductInProgress(MultiProductInner { + iters, + cur: Populated(_), + }) => { + if let [first, tail @ ..] = &iters[..] { + tail.iter().fold(first.iter.size_hint(), |mut sh, iter| { + sh = size_hint::mul(sh, iter.iter_orig.size_hint()); + size_hint::add(sh, iter.iter.size_hint()) + }) + } else { + // Since it is populated, this cartesian product has started so `iters` is not empty. + unreachable!() + } + } } - - self.0.iter().fold( - (0, Some(0)), - |acc, - MultiProductIter { - iter, - iter_orig, - cur: _, - }| { - let cur_size = iter.size_hint(); - let total_size = iter_orig.size_hint(); - size_hint::add(size_hint::mul(acc, total_size), cur_size) - }, - ) } fn last(self) -> Option { - let iter_count = self.0.len(); - - let lasts: Self::Item = self - .0 - .into_iter() - .map(|multi_iter| multi_iter.iter.last()) - .while_some() - .collect(); - - if lasts.len() == iter_count { - Some(lasts) + let MultiProductInner { iters, cur } = self.0?; + // Collect the last item of each iterator of the product. + if let Populated(values) = cur { + let mut count = iters.len(); + let last = iters + .into_iter() + .zip(values) + .map(|(i, value)| { + i.iter.last().unwrap_or_else(|| { + // The iterator is empty, use its current `value`. + count -= 1; + value + }) + }) + .collect(); + if count == 0 { + // `values` was the last item. + None + } else { + Some(last) + } } else { - None + iters.into_iter().map(|i| i.iter.last()).collect() } } } + +impl std::iter::FusedIterator for MultiProduct +where + I: Iterator + Clone, + I::Item: Clone, +{ +} diff --git a/src/lib.rs b/src/lib.rs index a9977e1f1..126eb2221 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1162,10 +1162,11 @@ pub trait Itertools: Iterator { /// the product of iterators yielding multiple types, use the /// [`iproduct`] macro instead. /// - /// /// The iterator element type is `Vec`, where `T` is the iterator element /// of the subiterators. /// + /// Note that the iterator is fused. + /// /// ``` /// use itertools::Itertools; /// let mut multi_prod = (0..3).map(|i| (i * 2)..(i * 2 + 2)) diff --git a/tests/quick.rs b/tests/quick.rs index dffcd22f6..954655a43 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -450,6 +450,12 @@ quickcheck! { assert_eq!(answer.into_iter().last(), a.multi_cartesian_product().last()); } + fn correct_empty_multi_product() -> () { + let empty = Vec::>::new().into_iter().multi_cartesian_product(); + assert!(correct_size_hint(empty.clone())); + itertools::assert_equal(empty, std::iter::once(Vec::new())) + } + #[allow(deprecated)] fn size_step(a: Iter, s: usize) -> bool { let mut s = s; diff --git a/tests/specializations.rs b/tests/specializations.rs index e14b1b669..fd8801e4e 100644 --- a/tests/specializations.rs +++ b/tests/specializations.rs @@ -163,7 +163,6 @@ quickcheck! { TestResult::passed() } - #[ignore] // It currently fails because `MultiProduct` is not fused. fn multi_cartesian_product(a: Vec, b: Vec, c: Vec) -> TestResult { if a.len() * b.len() * c.len() > 100 { return TestResult::discard();