Skip to content

Commit

Permalink
Fix cubic spline when x isn't sorted (#219)
Browse files Browse the repository at this point in the history
* Fix cubic spline when x isn't sorted

* Format
  • Loading branch information
msluszniak authored Dec 14, 2023
1 parent 6438349 commit 99a5286
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
3 changes: 1 addition & 2 deletions lib/scholar/interpolation/cubic_spline.ex
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ defmodule Scholar.Interpolation.CubicSpline do
"expected y to have shape #{inspect(x_shape)}, got: #{inspect(y_shape)}"
end

dx = Nx.diff(x)

sort_idx = Nx.argsort(x)
x = Nx.take(x, sort_idx)
y = Nx.take(y, sort_idx)

dx = Nx.diff(x)
dy = Nx.diff(y)

slope = dy / dx
Expand Down
53 changes: 53 additions & 0 deletions test/scholar/interpolation/cubic_spline_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -205,5 +205,58 @@ defmodule Scholar.Interpolation.CubicSplineTest do
assert CubicSpline.predict(model, target_x) == target_x
end
end

test "not sorted x" do
x = Nx.tensor([3, 2, 4, 1, 0])
y = Nx.tensor([-10, 3, -1, 2, 1])

model = CubicSpline.fit(x, y)

# ensure given values are predicted accurately
# also ensures that the code works for scalar tensors
assert_all_close(CubicSpline.predict(model, 0), Nx.tensor(1.0))
assert_all_close(CubicSpline.predict(model, 1), Nx.tensor(2.0))
assert_all_close(CubicSpline.predict(model, 2), Nx.tensor(3.0))
assert_all_close(CubicSpline.predict(model, 3), Nx.tensor(-10.0))
assert_all_close(CubicSpline.predict(model, 4), Nx.tensor(-1.0))

# Test for continuity over the given point's boundaries
# (helps ensure no off-by-one's are happening when selecting polynomials)
assert_all_close(
CubicSpline.predict(
model,
Nx.tensor([-0.001, 0.001, 0.999, 1.001, 1.999, 2.001, 2.999, 3.001, 3.999, 4.001])
),
Nx.tensor([
1.0078465938568115,
0.9921799302101135,
1.9945827722549438,
2.0054168701171875,
3.0078206062316895,
2.9921538829803467,
-9.989906311035156,
-10.010071754455566,
-1.0361297130584717,
-0.9638023376464844
])
)

# ensure reference values are calculated accordingly
x_predict = Nx.tensor([-1, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5])

assert_all_close(
CubicSpline.predict(model, x_predict),
Nx.tensor([
26.50000762939453,
8.78125286102295,
-0.15625077486038208,
4.15625,
-3.21875,
-11.28125,
26.90625,
78.50000762939453
])
)
end
end
end

0 comments on commit 99a5286

Please sign in to comment.