Skip to content

Commit

Permalink
mix format
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Jan 15, 2025
1 parent 4c27312 commit 9487016
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
26 changes: 24 additions & 2 deletions test/scholar/neighbors/knn_regressor_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,18 @@ defmodule Scholar.Neighbors.KNNRegressorTest do

test "predict with 2D labels" do
y =
Nx.tensor([[1, 4], [0, 3], [2, 5], [0, 3], [0, 3], [1, 4], [2, 5], [0, 3], [1, 4], [2, 5]])
Nx.tensor([
[1, 4],
[0, 3],
[2, 5],
[0, 3],
[0, 3],
[1, 4],
[2, 5],
[0, 3],
[1, 4],
[2, 5]
])

model = KNNRegressor.fit(x_train(), y, num_neighbors: 3)
y_pred = KNNRegressor.predict(model, x())
Expand All @@ -111,7 +122,18 @@ defmodule Scholar.Neighbors.KNNRegressorTest do

test "predict with 2D labels, cosine metric and weights set to :distance" do
y =
Nx.tensor([[1, 4], [0, 3], [2, 5], [0, 3], [0, 3], [1, 4], [2, 5], [0, 3], [1, 4], [2, 5]])
Nx.tensor([
[1, 4],
[0, 3],
[2, 5],
[0, 3],
[0, 3],
[1, 4],
[2, 5],
[0, 3],
[1, 4],
[2, 5]
])

model =
KNNRegressor.fit(x_train(), y, num_neighbors: 3, metric: :cosine, weights: :distance)
Expand Down
13 changes: 12 additions & 1 deletion test/scholar/neighbors/rnn_regressor_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,18 @@ defmodule Scholar.Neighbors.RadiusNNRegressorTest do

test "predict with weights set to :distance and with specific metric and 2d labels" do
y =
Nx.tensor([[1, 4], [0, 3], [2, 5], [0, 3], [0, 3], [1, 4], [2, 5], [0, 3], [1, 4], [2, 5]])
Nx.tensor([
[1, 4],
[0, 3],
[2, 5],
[0, 3],
[0, 3],
[1, 4],
[2, 5],
[0, 3],
[1, 4],
[2, 5]
])

model =
RadiusNNRegressor.fit(x(), y,
Expand Down

0 comments on commit 9487016

Please sign in to comment.