Skip to content

Commit

Permalink
Merge coefficient types on linear interpolation, closes #318
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Jan 17, 2025
1 parent b815c59 commit da18bb3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
5 changes: 2 additions & 3 deletions lib/scholar/interpolation/linear.ex
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,15 @@ defmodule Scholar.Interpolation.Linear do

defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do
shape = Nx.shape(target_x)

target_x = Nx.flatten(target_x)

indices = Nx.argsort(target_x)

left_bound = x[0]
right_bound = x[-1]

target_x = Nx.sort(target_x)
res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)})
type = Nx.Type.merge(to_float_type(target_x), coefficients.type)
res = Nx.broadcast(Nx.tensor(0, type: type), {Nx.axis_size(target_x, 0)})

# while with smaller than left_bound
{{res, i}, _} =
Expand Down
11 changes: 11 additions & 0 deletions test/scholar/interpolation/linear_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,16 @@ defmodule Scholar.Interpolation.LinearTest do
assert Linear.predict(model, Nx.tensor([[[-0.5], [0.5], [1.5], [2.5], [3.5]]])) ==
Nx.tensor([[[0.0], [0.0], [0.5], [3], [7]]])
end

test "with different types" do
x_s = Nx.tensor([1, 2, 3], type: :u64)
y_s = Nx.tensor([1.0, 2.0, 3.0], type: :f64)
target = Nx.tensor([1, 2], type: :u64)

assert x_s
|> Scholar.Interpolation.Linear.fit(y_s)
|> Scholar.Interpolation.Linear.predict(target) ==
Nx.tensor([1.0, 2.0], type: :f64)
end
end
end

0 comments on commit da18bb3

Please sign in to comment.