Skip to content

Commit

Permalink
Merge pull request #1305 from jonasBoss/axis_windows_dimension
Browse files Browse the repository at this point in the history
Change `NdProducer::Dim` of `axis_windows()` to `Ix1`
  • Loading branch information
bluss authored Aug 2, 2024
2 parents 45009ff + 21fb817 commit f163e14
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 41 deletions.
8 changes: 3 additions & 5 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use crate::iter::{
AxisChunksIterMut,
AxisIter,
AxisIterMut,
AxisWindows,
ExactChunks,
ExactChunksMut,
IndexedIter,
Expand Down Expand Up @@ -1521,7 +1522,7 @@ where
/// assert_eq!(window.shape(), &[4, 3, 2]);
/// }
/// ```
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D>
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
where S: Data
{
let axis_index = axis.index();
Expand All @@ -1537,10 +1538,7 @@ where
self.shape()
);

let mut size = self.raw_dim();
size[axis_index] = window_size;

Windows::new(self.view(), size)
AxisWindows::new(self.view(), axis, window_size)
}

// Return (length, stride) for diagonal
Expand Down
1 change: 1 addition & 0 deletions src/iterators/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub use crate::iterators::{
AxisChunksIterMut,
AxisIter,
AxisIterMut,
AxisWindows,
ExactChunks,
ExactChunksIter,
ExactChunksIterMut,
Expand Down
2 changes: 1 addition & 1 deletion src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use super::{Dimension, Ix, Ixs};
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
pub use self::into_iter::IntoIter;
pub use self::lanes::{Lanes, LanesMut};
pub use self::windows::Windows;
pub use self::windows::{AxisWindows, Windows};

use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

Expand Down
199 changes: 164 additions & 35 deletions src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D>
let strides = axis_strides.into_dimension();
let window_strides = a.strides.clone();

ndassert!(
a.ndim() == window.ndim(),
concat!(
"Window dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
window.ndim(),
a.ndim(),
a.shape()
);

ndassert!(
a.ndim() == strides.ndim(),
concat!(
"Stride dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
strides.ndim(),
a.ndim(),
a.shape()
);

let mut base = a;
base.slice_each_axis_inplace(|ax_desc| {
let len = ax_desc.len;
let wsz = window[ax_desc.axis.index()];
let stride = strides[ax_desc.axis.index()];

if len < wsz {
Slice::new(0, Some(0), 1)
} else {
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
}
});

let base = build_base(a, window.clone(), strides);
Windows {
base: base.into_raw_view(),
life: PhantomData,
Expand Down Expand Up @@ -160,3 +126,166 @@ impl_iterator! {

send_sync_read_only!(Windows);
send_sync_read_only!(WindowsIter);

/// Window producer and iterable
///
/// See [`.axis_windows()`](ArrayBase::axis_windows) for more
/// information.
pub struct AxisWindows<'a, A, D>
{
base: ArrayView<'a, A, D>,
axis_idx: usize,
window: D,
strides: D,
}

impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
{
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
{
let window_strides = a.strides.clone();
let axis_idx = axis.index();

let mut window = a.raw_dim();
window[axis_idx] = window_size;

let ndim = window.ndim();
let mut unit_stride = D::zeros(ndim);
unit_stride.slice_mut().fill(1);

let base = build_base(a, window.clone(), unit_stride);
AxisWindows {
base,
axis_idx,
window,
strides: window_strides,
}
}
}

impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D>
{
type Item = ArrayView<'a, A, D>;
type Dim = Ix1;
type Ptr = *mut A;
type Stride = isize;

fn raw_dim(&self) -> Ix1
{
Ix1(self.base.raw_dim()[self.axis_idx])
}

fn layout(&self) -> Layout
{
self.base.layout()
}

fn as_ptr(&self) -> *mut A
{
self.base.as_ptr() as *mut _
}

fn contiguous_stride(&self) -> isize
{
self.base.contiguous_stride()
}

unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item
{
ArrayView::new_(ptr, self.window.clone(), self.strides.clone())
}

unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A
{
let mut d = D::zeros(self.base.ndim());
d[self.axis_idx] = i[0];
self.base.uget_ptr(&d)
}

fn stride_of(&self, axis: Axis) -> isize
{
assert_eq!(axis, Axis(0));
self.base.stride_of(Axis(self.axis_idx))
}

fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
{
assert_eq!(axis, Axis(0));
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
(
AxisWindows {
base: a,
axis_idx: self.axis_idx,
window: self.window.clone(),
strides: self.strides.clone(),
},
AxisWindows {
base: b,
axis_idx: self.axis_idx,
window: self.window,
strides: self.strides,
},
)
}

private_impl!{}
}

impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D>
where
D: Dimension,
A: 'a,
{
type Item = <Self::IntoIter as Iterator>::Item;
type IntoIter = WindowsIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter
{
WindowsIter {
iter: self.base.into_base_iter(),
life: PhantomData,
window: self.window,
strides: self.strides,
}
}
}

/// build the base array of the `Windows` and `AxisWindows` structs
fn build_base<A, D>(a: ArrayView<A, D>, window: D, strides: D) -> ArrayView<A, D>
where D: Dimension
{
ndassert!(
a.ndim() == window.ndim(),
concat!(
"Window dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
window.ndim(),
a.ndim(),
a.shape()
);

ndassert!(
a.ndim() == strides.ndim(),
concat!(
"Stride dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
strides.ndim(),
a.ndim(),
a.shape()
);

let mut base = a;
base.slice_each_axis_inplace(|ax_desc| {
let len = ax_desc.len;
let wsz = window[ax_desc.axis.index()];
let stride = strides[ax_desc.axis.index()];

if len < wsz {
Slice::new(0, Some(0), 1)
} else {
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
}
});
base
}
16 changes: 16 additions & 0 deletions tests/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,22 @@ fn test_axis_windows_3d()
]);
}

#[test]
fn tests_axis_windows_3d_zips_with_1d()
{
let a = Array::from_iter(0..27)
.into_shape_with_order((3, 3, 3))
.unwrap();
let mut b = Array::zeros(2);

Zip::from(b.view_mut())
.and(a.axis_windows(Axis(1), 2))
.for_each(|b, a| {
*b = a.sum();
});
assert_eq!(b,arr1(&[207, 261]));
}

#[test]
fn test_window_neg_stride()
{
Expand Down

0 comments on commit f163e14

Please sign in to comment.