Skip to content

Commit

Permalink
return result from solve_for_k and solve_for_k_individual fn's
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasBoss committed Dec 12, 2023
1 parent c20d2a7 commit 9d21f2f
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/interp1d/strategies/cubic_spline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::{
};

use ndarray::{
s, Array, Array1, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dimension, Ix1, IxDyn,
RemoveAxis, ScalarOperand, Zip,
s, Array, Array1, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dimension, FoldWhile, Ix1,
IxDyn, RemoveAxis, ScalarOperand, Zip,
};
use num_traits::{cast, Num, NumCast, Pow};

Expand Down Expand Up @@ -300,9 +300,9 @@ where
x,
data.view().into_dyn(),
bounds.view().into_dyn(),
);
)
}
};
}?;

let mut a_b_dim = data.raw_dim();
a_b_dim[0] -= 1;
Expand All @@ -329,15 +329,22 @@ where
x: &ArrayBase<Sx, Ix1>,
data: ArrayView<T, IxDyn>,
boundary: ArrayView<RowBoundary<T>, IxDyn>,
) where
) -> Result<(), BuilderError>
where
Sx: Data<Elem = T>,
{
if k.ndim() > 1 {
let ax = Axis(k.ndim() - 1);
Zip::from(k.axis_iter_mut(ax))
.and(data.axis_iter(ax))
.and(boundary.axis_iter(ax))
.for_each(|k, data, boundary| Self::solve_for_k_individual(k, x, data, boundary))
.fold_while(Ok(()), |_, k, data, boundary| {
Self::solve_for_k_individual(k, &x, data, boundary).map_or_else(
|err| FoldWhile::Done(Err(err)),
|_| FoldWhile::Continue(Ok(())),
)
})
.into_inner()
} else {
Self::solve_for_k(
k,
Expand All @@ -357,7 +364,8 @@ where
x: &ArrayBase<Sx, Ix1>,
data: &ArrayBase<Sd, _D>,
boundary: RowBoundary<T>,
) where
) -> Result<(), BuilderError>
where
_D: Dimension + RemoveAxis,
Sd: Data<Elem = T>,
Sx: Data<Elem = T>,
Expand Down Expand Up @@ -417,7 +425,9 @@ where

// apply boundary conditions
match (boundary.specialize(), len) {
(RowBoundary::Periodic, _) => todo!(),
(RowBoundary::Periodic, _) => {
todo!()
}
(RowBoundary::Clamped, _) => unreachable!(),
(RowBoundary::Natural, _) => unreachable!(),
(RowBoundary::NotAKnot, _) => unreachable!(),
Expand Down Expand Up @@ -533,6 +543,7 @@ where
}
}
Self::thomas(k, a_up, a_mid, a_low, rhs);
Ok(())
}

/// The Thomas algorithm is used, because the matrix A will be tridiagonal and diagonally dominant
Expand Down

0 comments on commit 9d21f2f

Please sign in to comment.