diff --git a/src/interp1d/strategies/cubic_spline.rs b/src/interp1d/strategies/cubic_spline.rs index b8c5437..58fe474 100644 --- a/src/interp1d/strategies/cubic_spline.rs +++ b/src/interp1d/strategies/cubic_spline.rs @@ -287,8 +287,14 @@ where BoundaryCondition::Clamped => Self::solve_for_k(kv, x, data, RowBoundary::Clamped), BoundaryCondition::NotAKnot => Self::solve_for_k(kv, x, data, RowBoundary::NotAKnot), BoundaryCondition::Individual(bounds) => { - assert!(kv.raw_dim().remove_axis(AX0) == bounds.raw_dim().remove_axis(AX0)); // TODO: return error - assert!(bounds.raw_dim()[0] == 1); // TODO: return error + let mut bounds_shape = kv.raw_dim(); + bounds_shape[0] = 1; + if bounds_shape != bounds.raw_dim() { + return Err(BuilderError::ShapeError(format!( + "Boundary conditions array has wrong shape. Expected: {bounds_shape:?}, got: {:?}", + bounds.raw_dim() + ))); + } Self::solve_for_k_individual( kv.into_dyn(), x, diff --git a/tests/cubic_spline_strat.rs b/tests/cubic_spline_strat.rs index 59eea88..1b567fd 100644 --- a/tests/cubic_spline_strat.rs +++ b/tests/cubic_spline_strat.rs @@ -409,3 +409,32 @@ fn extrapolate_deriv2() { ]; assert_relative_eq!(res, expect, epsilon = f64::EPSILON, max_relative = 0.001); } + +#[test] +#[should_panic(expected = "Expected: [1, 2], got: [1, 3]")] +fn bounds_shape_error1() { + let y = array![[0.5, 1.0], [0.0, 1.5], [3.0, 0.5],]; + let boundaries = BoundaryCondition::Individual(array![[ + RowBoundary::Natural, + RowBoundary::Periodic, + RowBoundary::NotAKnot + ],]); + Interp1DBuilder::new(y) + .strategy(CubicSpline::new().boundary(boundaries)) + .build() + .unwrap(); +} + +#[test] +#[should_panic(expected = "Expected: [1, 2], got: [2, 2]")] +fn bounds_shape_error2() { + let y = array![[0.5, 1.0], [0.0, 1.5], [3.0, 0.5],]; + let boundaries = BoundaryCondition::Individual(array![ + [RowBoundary::Natural, RowBoundary::NotAKnot], + [RowBoundary::Natural, RowBoundary::NotAKnot], + ]); + Interp1DBuilder::new(y) + .strategy(CubicSpline::new().boundary(boundaries)) + .build() + .unwrap(); +}