diff --git a/lib/scholar/interpolation/linear.ex b/lib/scholar/interpolation/linear.ex index d981eeb8..4d2c7be1 100644 --- a/lib/scholar/interpolation/linear.ex +++ b/lib/scholar/interpolation/linear.ex @@ -131,16 +131,15 @@ defmodule Scholar.Interpolation.Linear do defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do shape = Nx.shape(target_x) - target_x = Nx.flatten(target_x) - indices = Nx.argsort(target_x) left_bound = x[0] right_bound = x[-1] target_x = Nx.sort(target_x) - res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)}) + type = Nx.Type.merge(to_float_type(target_x), coefficients.type) + res = Nx.broadcast(Nx.tensor(0, type: type), {Nx.axis_size(target_x, 0)}) # while with smaller than left_bound {{res, i}, _} = diff --git a/test/scholar/interpolation/linear_test.exs b/test/scholar/interpolation/linear_test.exs index c5980209..0d6431a6 100644 --- a/test/scholar/interpolation/linear_test.exs +++ b/test/scholar/interpolation/linear_test.exs @@ -51,5 +51,16 @@ defmodule Scholar.Interpolation.LinearTest do assert Linear.predict(model, Nx.tensor([[[-0.5], [0.5], [1.5], [2.5], [3.5]]])) == Nx.tensor([[[0.0], [0.0], [0.5], [3], [7]]]) end + + test "with different types" do + x_s = Nx.tensor([1, 2, 3], type: :u64) + y_s = Nx.tensor([1.0, 2.0, 3.0], type: :f64) + target = Nx.tensor([1, 2], type: :u64) + + assert x_s + |> Scholar.Interpolation.Linear.fit(y_s) + |> Scholar.Interpolation.Linear.predict(target) == + Nx.tensor([1.0, 2.0], type: :f64) + end end end