From 82c86f4941c2c5ec99c8d97d30a736ea22f6a648 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Wed, 13 Dec 2023 15:53:35 +0100 Subject: [PATCH 1/2] Fix cubic spline when x isn't sorted --- lib/scholar/interpolation/cubic_spline.ex | 2 +- .../interpolation/cubic_spline_test.exs | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/lib/scholar/interpolation/cubic_spline.ex b/lib/scholar/interpolation/cubic_spline.ex index 00901f00..aa629d4e 100644 --- a/lib/scholar/interpolation/cubic_spline.ex +++ b/lib/scholar/interpolation/cubic_spline.ex @@ -88,12 +88,12 @@ defmodule Scholar.Interpolation.CubicSpline do "expected y to have shape #{inspect(x_shape)}, got: #{inspect(y_shape)}" end - dx = Nx.diff(x) sort_idx = Nx.argsort(x) x = Nx.take(x, sort_idx) y = Nx.take(y, sort_idx) + dx = Nx.diff(x) dy = Nx.diff(y) slope = dy / dx diff --git a/test/scholar/interpolation/cubic_spline_test.exs b/test/scholar/interpolation/cubic_spline_test.exs index b6d8077b..c2c6f12b 100644 --- a/test/scholar/interpolation/cubic_spline_test.exs +++ b/test/scholar/interpolation/cubic_spline_test.exs @@ -205,5 +205,59 @@ defmodule Scholar.Interpolation.CubicSplineTest do assert CubicSpline.predict(model, target_x) == target_x end end + + test "not sorted x" do + + x = Nx.tensor([3, 2, 4, 1, 0]) + y = Nx.tensor([-10, 3, -1, 2, 1]) + + model = CubicSpline.fit(x, y) + + # ensure given values are predicted accurately + # also ensures that the code works for scalar tensors + assert_all_close(CubicSpline.predict(model, 0), Nx.tensor(1.0)) + assert_all_close(CubicSpline.predict(model, 1), Nx.tensor(2.0)) + assert_all_close(CubicSpline.predict(model, 2), Nx.tensor(3.0)) + assert_all_close(CubicSpline.predict(model, 3), Nx.tensor(-10.0)) + assert_all_close(CubicSpline.predict(model, 4), Nx.tensor(-1.0)) + + # Test for continuity over the given point's boundaries + # (helps ensure no off-by-one's are happening when selecting polynomials) + assert_all_close( + CubicSpline.predict( + model, + Nx.tensor([-0.001, 0.001, 0.999, 1.001, 1.999, 2.001, 2.999, 3.001, 3.999, 4.001]) + ), + Nx.tensor([ + 1.0078465938568115, + 0.9921799302101135, + 1.9945827722549438, + 2.0054168701171875, + 3.0078206062316895, + 2.9921538829803467, + -9.989906311035156, + -10.010071754455566, + -1.0361297130584717, + -0.9638023376464844 + ]) + ) + + # ensure reference values are calculated accordingly + x_predict = Nx.tensor([-1, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5]) + + assert_all_close( + CubicSpline.predict(model, x_predict), + Nx.tensor([ + 26.50000762939453, + 8.78125286102295, + -0.15625077486038208, + 4.15625, + -3.21875, + -11.28125, + 26.90625, + 78.50000762939453 + ]) + ) + end end end From 1826609ec9e75a5b3726c1149a46af20ff2ed7ba Mon Sep 17 00:00:00 2001 From: Mateusz Date: Wed, 13 Dec 2023 16:00:28 +0100 Subject: [PATCH 2/2] Format --- lib/scholar/interpolation/cubic_spline.ex | 1 - test/scholar/interpolation/cubic_spline_test.exs | 1 - 2 files changed, 2 deletions(-) diff --git a/lib/scholar/interpolation/cubic_spline.ex b/lib/scholar/interpolation/cubic_spline.ex index aa629d4e..8f95a644 100644 --- a/lib/scholar/interpolation/cubic_spline.ex +++ b/lib/scholar/interpolation/cubic_spline.ex @@ -88,7 +88,6 @@ defmodule Scholar.Interpolation.CubicSpline do "expected y to have shape #{inspect(x_shape)}, got: #{inspect(y_shape)}" end - sort_idx = Nx.argsort(x) x = Nx.take(x, sort_idx) y = Nx.take(y, sort_idx) diff --git a/test/scholar/interpolation/cubic_spline_test.exs b/test/scholar/interpolation/cubic_spline_test.exs index c2c6f12b..c05ec9bf 100644 --- a/test/scholar/interpolation/cubic_spline_test.exs +++ b/test/scholar/interpolation/cubic_spline_test.exs @@ -207,7 +207,6 @@ defmodule Scholar.Interpolation.CubicSplineTest do end test "not sorted x" do - x = Nx.tensor([3, 2, 4, 1, 0]) y = Nx.tensor([-10, 3, -1, 2, 1])