diff --git a/lib/scholar/interpolation/cubic_spline.ex b/lib/scholar/interpolation/cubic_spline.ex index 00901f00..8f95a644 100644 --- a/lib/scholar/interpolation/cubic_spline.ex +++ b/lib/scholar/interpolation/cubic_spline.ex @@ -88,12 +88,11 @@ defmodule Scholar.Interpolation.CubicSpline do "expected y to have shape #{inspect(x_shape)}, got: #{inspect(y_shape)}" end - dx = Nx.diff(x) - sort_idx = Nx.argsort(x) x = Nx.take(x, sort_idx) y = Nx.take(y, sort_idx) + dx = Nx.diff(x) dy = Nx.diff(y) slope = dy / dx diff --git a/test/scholar/interpolation/cubic_spline_test.exs b/test/scholar/interpolation/cubic_spline_test.exs index b6d8077b..c05ec9bf 100644 --- a/test/scholar/interpolation/cubic_spline_test.exs +++ b/test/scholar/interpolation/cubic_spline_test.exs @@ -205,5 +205,58 @@ defmodule Scholar.Interpolation.CubicSplineTest do assert CubicSpline.predict(model, target_x) == target_x end end + + test "not sorted x" do + x = Nx.tensor([3, 2, 4, 1, 0]) + y = Nx.tensor([-10, 3, -1, 2, 1]) + + model = CubicSpline.fit(x, y) + + # ensure given values are predicted accurately + # also ensures that the code works for scalar tensors + assert_all_close(CubicSpline.predict(model, 0), Nx.tensor(1.0)) + assert_all_close(CubicSpline.predict(model, 1), Nx.tensor(2.0)) + assert_all_close(CubicSpline.predict(model, 2), Nx.tensor(3.0)) + assert_all_close(CubicSpline.predict(model, 3), Nx.tensor(-10.0)) + assert_all_close(CubicSpline.predict(model, 4), Nx.tensor(-1.0)) + + # Test for continuity over the given point's boundaries + # (helps ensure no off-by-one's are happening when selecting polynomials) + assert_all_close( + CubicSpline.predict( + model, + Nx.tensor([-0.001, 0.001, 0.999, 1.001, 1.999, 2.001, 2.999, 3.001, 3.999, 4.001]) + ), + Nx.tensor([ + 1.0078465938568115, + 0.9921799302101135, + 1.9945827722549438, + 2.0054168701171875, + 3.0078206062316895, + 2.9921538829803467, + -9.989906311035156, + -10.010071754455566, + -1.0361297130584717, + -0.9638023376464844 + ]) + ) + + # ensure reference values are calculated accordingly + x_predict = Nx.tensor([-1, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5]) + + assert_all_close( + CubicSpline.predict(model, x_predict), + Nx.tensor([ + 26.50000762939453, + 8.78125286102295, + -0.15625077486038208, + 4.15625, + -3.21875, + -11.28125, + 26.90625, + 78.50000762939453 + ]) + ) + end end end