From af2135ecd91596224b55e45141208351b6a5ff55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Tue, 21 May 2024 20:10:54 +0200 Subject: [PATCH 01/13] Add incremental PCA --- lib/scholar/decomposition/incremental_pca.ex | 240 +++++++++++++++++++ lib/scholar/decomposition/pca.ex | 17 +- lib/scholar/neighbors/brute_knn.ex | 21 -- lib/scholar/shared.ex | 21 ++ 4 files changed, 264 insertions(+), 35 deletions(-) create mode 100644 lib/scholar/decomposition/incremental_pca.ex diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex new file mode 100644 index 00000000..b3e6f3e5 --- /dev/null +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -0,0 +1,240 @@ +defmodule Scholar.Decomposition.IncrementalPCA do + @moduledoc """ + Incremental Principal Component Analysis. + + References: + + * [1] - [Incremental Learning for Robust Visual Tracking](https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf) + """ + import Nx.Defn + import Scholar.Shared + require Nx + + @derive {Nx.Container, + keep: [:num_samples_seen, :num_components], + containers: [ + :components, + :singular_values, + :mean, + :variance, + :explained_variance, + :explained_variance_ratio + ]} + defstruct [ + :num_samples_seen, + :num_components, + :components, + :singular_values, + :mean, + :variance, + :explained_variance, + :explained_variance_ratio + ] + + opts = [ + num_components: [ + type: :pos_integer, + doc: "???" + ], + batch_size: [ + type: :pos_integer, + doc: "The number of samples in a batch." + ], + whiten: [ + type: :boolean, + default: false, + doc: "???" + ] + ] + + @opts_schema NimbleOptions.new!(opts) + + deftransform fit(x, opts) do + if Nx.rank(x) != 2 do + raise ArgumentError, + "expected input tensor to have shape {num_samples, num_features}, + got tensor with shape: #{inspect(Nx.shape(x))}" + end + + opts = NimbleOptions.validate!(opts, @opts_schema) + + {num_samples, num_features} = Nx.shape(x) + + batch_size = + if opts[:batch_size] do + opts[:batch_size] + else + 5 * num_features + end + + # TODO: What to do if batch_size is greater than num_samples? Probably raise an error. 
+ + num_components = opts[:num_components] + + num_components = + if num_components do + cond do + num_components > batch_size -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + batch_size = #{batch_size}, got #{num_components} + """ + + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_samples = #{num_samples}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + + true -> + num_components + end + else + Enum.min([batch_size, num_samples, num_features]) + end + + opts = Keyword.put(opts, :num_components, num_components) + opts = Keyword.put(opts, :batch_size, batch_size) + + fit_n(x, opts) + end + + defn fit_n(x, opts) do + batch_size = opts[:batch_size] + {batches, leftover} = get_batches(x, batch_size: batch_size) + + num_batches = Nx.axis_size(batches, 0) + model = fit_first(batches[0], opts) + + {model, _} = + while { + model, + { + batches, + i = Nx.u64(1) + } + }, + i < num_batches do + batch = batches[i] + model = partial_fit_n(model, batch) + {model, {batches, i + 1}} + end + + # partial_fit_n(model, leftover) + model + end + + defnp fit_first(x, opts) do + # This is similar to Scholar.Decomposition.PCA.fit_n + {num_samples, _} = Nx.shape(x) + num_components = opts[:num_components] + mean = Nx.mean(x, axes: [0]) + variance = Nx.variance(x, axes: [0]) + x_centered = x - mean + {u, s, vt} = Nx.LinAlg.svd(x_centered, full_matrices?: false) + {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt) + components = vt[[0..(num_components - 1)]] + singular_values = s[[0..(num_components - 1)]] + explained_variance = s * s / (num_samples - 1) + + explained_variance_ratio = + (explained_variance / Nx.sum(explained_variance))[[0..(num_components - 1)]] + + %__MODULE__{ + num_samples_seen: num_samples, + num_components: num_components, + components: components, + singular_values: singular_values, + mean: mean, + variance: variance, + explained_variance: explained_variance[[0..(num_components - 1)]], + explained_variance_ratio: explained_variance_ratio + } + end + + defnp partial_fit_n(model, x) do + num_samples_seen = model.num_samples_seen + num_components = model.num_components + components = model.components + singular_values = model.singular_values + mean = model.mean + variance = model.variance + {num_samples, _} = Nx.shape(x) + last_sample_count = Nx.broadcast(num_samples_seen, {Nx.axis_size(x, 1)}) + + {num_total_samples, col_mean, col_variance} = + incremental_mean_and_var(x, last_sample_count, mean, variance) + + num_total_samples = num_total_samples[0] + # if num_samples_seen > 0 do + col_batch_mean = Nx.mean(x, axes: [0]) + x_centered = x - col_batch_mean + + mean_correction = + Nx.sqrt(num_samples_seen / num_total_samples) * num_samples * (mean - col_batch_mean) + + mean_correction = Nx.new_axis(mean_correction, 0) + to_add = Nx.reshape(singular_values, {1, :auto}) * components + z = Nx.concatenate([Nx.transpose(to_add), x_centered, mean_correction], axis: 0) + {u, s, vt} = Nx.LinAlg.svd(z, full_matrices?: false) + {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt) + components = vt[[0..(num_components - 1)]] + singular_values = s[[0..(num_components - 1)]] + explained_variance = s * s / (num_total_samples - 1) + + explained_variance_ratio = + singular_values * singular_values / Nx.sum(col_variance * num_total_samples) + + 
%__MODULE__{ + num_samples_seen: num_total_samples, + num_components: num_components, + components: components, + singular_values: singular_values, + mean: col_mean, + variance: col_variance, + explained_variance: explained_variance, + explained_variance_ratio: explained_variance_ratio + } + end + + defnp incremental_mean_and_var(x, last_sample_count, last_mean, last_variance) do + new_sample_count = Nx.axis_size(x, 0) + updated_sample_count = last_sample_count + new_sample_count + last_sum = last_sample_count * last_mean + new_sum = Nx.sum(x, axes: [0]) + updated_mean = (last_sum + new_sum) / updated_sample_count + t = new_sum / new_sample_count + temp = x - t + correction = Nx.sum(temp, axes: [0]) + temp = temp * temp + + new_unnormalized_variance = + Nx.sum(temp, axes: [0]) - correction * correction / new_sample_count + + last_unnormalized_variance = last_sample_count * last_variance + last_over_new_count = last_sample_count / new_sample_count + + updated_unnormalized_variance = + last_unnormalized_variance + + new_unnormalized_variance + + last_over_new_count / updated_sample_count * + (last_sum / last_over_new_count - new_sum) ** 2 + + zeros = last_sample_count == 0 + + updated_unnormalized_variance = + Nx.select(zeros, new_unnormalized_variance, updated_unnormalized_variance) + + updated_variance = updated_unnormalized_variance / updated_sample_count + {updated_sample_count, updated_mean, updated_variance} + end +end diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index 97d6d620..c0bcb46f 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -15,15 +15,13 @@ defmodule Scholar.Decomposition.PCA do import Nx.Defn @derive {Nx.Container, - keep: [:num_components], + keep: [:num_components, :num_samples, :num_features], containers: [ :components, :explained_variance, :explained_variance_ratio, :singular_values, - :mean, - :num_features, - :num_samples + :mean ]} defstruct [ :components, @@ -157,7 +155,7 @@ defmodule Scholar.Decomposition.PCA do num_samples ) - {_, components} = flip_svd(decomposer, components) + {_, components} = Scholar.Decomposition.Utils.flip_svd(decomposer, components) components = components[[0..(num_components - 1), ..]] explained_variance = singular_values * singular_values / (num_samples - 1) @@ -289,15 +287,6 @@ defmodule Scholar.Decomposition.PCA do end end - defnp flip_svd(u, v) do - # columns of u, rows of v - max_abs_cols_idx = u |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true) - signs = u |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze() - u = u * signs - v = v * Nx.new_axis(signs, -1) - {u, v} - end - deftransformp calculate_num_components( num_components, num_features, diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 27afa7d6..7ebe437f 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -209,27 +209,6 @@ defmodule Scholar.Neighbors.BruteKNN do {neighbor_indices, neighbor_distances} end - defn get_batches(tensor, opts) do - {size, dim} = Nx.shape(tensor) - batch_size = opts[:batch_size] - num_batches = div(size, batch_size) - leftover_size = rem(size, batch_size) - - batches = - tensor - |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) - |> Nx.reshape({num_batches, batch_size, dim}) - - leftover = - if leftover_size > 0 do - Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) - else - nil - end - - {batches, leftover} - end - defnp 
brute_force_search(data, query, opts) do k = opts[:num_neighbors] metric = opts[:metric] diff --git a/lib/scholar/shared.ex b/lib/scholar/shared.ex index 61330283..59a90fe2 100644 --- a/lib/scholar/shared.ex +++ b/lib/scholar/shared.ex @@ -89,4 +89,25 @@ defmodule Scholar.Shared do valid_broadcast(to_parse - 1, n_dims, shape1, shape2) end + + defn get_batches(tensor, opts) do + {size, dim} = Nx.shape(tensor) + batch_size = min(opts[:batch_size], size) + num_batches = div(size, batch_size) + leftover_size = rem(size, batch_size) + + batches = + tensor + |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) + |> Nx.reshape({num_batches, batch_size, dim}) + + leftover = + if leftover_size > 0 do + Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) + else + nil + end + + {batches, leftover} + end end From 643622abf1c5c56a5c76f630d3807db3f42a62b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Wed, 22 May 2024 13:35:08 +0200 Subject: [PATCH 02/13] Update PCA unit-tests --- lib/scholar/decomposition/pca.ex | 10 +++------- test/scholar/decomposition/pca_test.exs | 12 ++++++------ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index c0bcb46f..24f223ef 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -123,12 +123,8 @@ defmodule Scholar.Decomposition.PCA do [0.0, 0.0] ), num_components: 2, - num_features: Nx.tensor( - 2 - ), - num_samples: Nx.tensor( - 6 - ) + num_features: 2, + num_samples: 6 } """ deftransform fit(x, opts \\ []) do @@ -277,7 +273,7 @@ defmodule Scholar.Decomposition.PCA do num_samples ) - {decomposer, _components} = flip_svd(decomposer, components) + {decomposer, _components} = Scholar.Decomposition.Utils.flip_svd(decomposer, components) decomposer = decomposer[[.., 0..(num_components - 1)]] if opts[:whiten] do diff --git a/test/scholar/decomposition/pca_test.exs b/test/scholar/decomposition/pca_test.exs index e332b971..a5cec890 100644 --- a/test/scholar/decomposition/pca_test.exs +++ b/test/scholar/decomposition/pca_test.exs @@ -40,8 +40,8 @@ defmodule Scholar.Decomposition.PCATest do assert_all_close(model.singular_values, Nx.tensor([6.30061232, 0.54980396]), atol: 1.0e-3) assert model.mean == Nx.tensor([0.0, 0.0]) assert model.num_components == 2 - assert model.num_samples == Nx.tensor(6) - assert model.num_features == Nx.tensor(2) + assert model.num_samples == 6 + assert model.num_features == 2 end test "fit test - :num_components is integer" do @@ -70,8 +70,8 @@ defmodule Scholar.Decomposition.PCATest do assert_all_close(model.singular_values, Nx.tensor([38.89730453491211]), atol: 1.0e-3) assert model.mean == Nx.tensor([28.5, 2.0, 3.5]) assert model.num_components == 1 - assert model.num_samples == Nx.tensor(2) - assert model.num_features == Nx.tensor(3) + assert model.num_samples == 2 + assert model.num_features == 3 end test "transform test - :whiten set to false" do @@ -121,8 +121,8 @@ defmodule Scholar.Decomposition.PCATest do assert_all_close(model.mean, Nx.tensor([3.83333333, 0.33333333, 1.66666667]), atol: 1.0e-2) assert model.num_components == 2 - assert model.num_samples == Nx.tensor(6) - assert model.num_features == Nx.tensor(3) + assert model.num_samples == 6 + assert model.num_features == 3 assert_all_close( PCA.transform(model, x3()), From bed9afdc0d34255d1e93dfa7cd683d111659c25e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 18 Jul 2024 20:59:11 +0200 
Subject: [PATCH 03/13] Add Incremental PCA --- lib/scholar/decomposition/incremental_pca.ex | 323 ++++++++++++------ lib/scholar/shared.ex | 5 +- .../decomposition/incremental_pca_test.exs | 5 + 3 files changed, 228 insertions(+), 105 deletions(-) create mode 100644 test/scholar/decomposition/incremental_pca_test.exs diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index b3e6f3e5..920c7da3 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -2,61 +2,84 @@ defmodule Scholar.Decomposition.IncrementalPCA do @moduledoc """ Incremental Principal Component Analysis. + Description goes here (elaborate on incremental approach) + References: - * [1] - [Incremental Learning for Robust Visual Tracking](https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf) + * [1] [Incremental Learning for Robust Visual Tracking](https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf) """ import Nx.Defn import Scholar.Shared require Nx @derive {Nx.Container, - keep: [:num_samples_seen, :num_components], + keep: [:num_components, :whiten?], containers: [ :components, :singular_values, + :num_samples_seen, :mean, :variance, :explained_variance, :explained_variance_ratio ]} defstruct [ - :num_samples_seen, :num_components, + :num_samples_seen, :components, :singular_values, :mean, :variance, :explained_variance, - :explained_variance_ratio + :explained_variance_ratio, + :whiten? ] - opts = [ + stream_opts = [ num_components: [ + required: true, type: :pos_integer, - doc: "???" + doc: "The number of principal components." ], - batch_size: [ - type: :pos_integer, - doc: "The number of samples in a batch." - ], - whiten: [ + whiten?: [ type: :boolean, default: false, - doc: "???" + doc: """ + When true the `components` are divided by `num_samples` times `components` to ensure uncorrelated outputs with unit component-wise variances. + + Whitening will remove some information from the transformed signal (the relative variance scales of the components) + but can sometimes improve the predictive accuracy of the downstream estimators by making data respect some hard-wired assumptions. + """ ] ] - @opts_schema NimbleOptions.new!(opts) + tensor_opts = + stream_opts ++ + [ + batch_size: [ + type: :pos_integer, + doc: "The number of samples in a batch." + ] + ] + + @stream_schema NimbleOptions.new!(stream_opts) + @tensor_schema NimbleOptions.new!(tensor_opts) - deftransform fit(x, opts) do + @doc """ + Fits an Incremental PCA model. + + ## Options + + #{NimbleOptions.docs(@tensor_schema)} + """ + deftransform fit(%Nx.Tensor{} = x, opts) do if Nx.rank(x) != 2 do raise ArgumentError, "expected input tensor to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}" end - opts = NimbleOptions.validate!(opts, @opts_schema) + opts = NimbleOptions.validate!(opts, @tensor_schema) {num_samples, num_features} = Nx.shape(x) @@ -67,42 +90,35 @@ defmodule Scholar.Decomposition.IncrementalPCA do 5 * num_features end - # TODO: What to do if batch_size is greater than num_samples? Probably raise an error. 
- num_components = opts[:num_components] - num_components = - if num_components do - cond do - num_components > batch_size -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - batch_size = #{batch_size}, got #{num_components} - """ - - num_components > num_samples -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - num_samples = #{num_samples}, got #{num_components} - """ - - num_components > num_features -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - num_features = #{num_features}, got #{num_components} - """ - - true -> - num_components - end - else - Enum.min([batch_size, num_samples, num_features]) - end + cond do + num_components > batch_size -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + batch_size = #{batch_size}, got #{num_components} + """ + + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_samples = #{num_samples}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + + true -> + nil + end - opts = Keyword.put(opts, :num_components, num_components) + # TODO: Check this! opts = Keyword.put(opts, :batch_size, batch_size) fit_n(x, opts) @@ -110,10 +126,12 @@ defmodule Scholar.Decomposition.IncrementalPCA do defn fit_n(x, opts) do batch_size = opts[:batch_size] - {batches, leftover} = get_batches(x, batch_size: batch_size) + + {batches, leftover} = + get_batches(x, batch_size: batch_size, min_batch_size: opts[:num_components]) num_batches = Nx.axis_size(batches, 0) - model = fit_first(batches[0], opts) + model = fit_first_n(batches[0], opts) {model, _} = while { @@ -129,13 +147,61 @@ defmodule Scholar.Decomposition.IncrementalPCA do {model, {batches, i + 1}} end - # partial_fit_n(model, leftover) + model = + case leftover do + nil -> model + _ -> partial_fit_n(model, leftover) + end + model end - defnp fit_first(x, opts) do + @doc """ + Fits an Incremental PCA model on a stream of batches. 
+ + ## Options + + #{NimbleOptions.docs(@stream_schema)} + """ + def fit(batches = %Stream{}, opts) do + opts = NimbleOptions.validate!(opts, @stream_schema) + # This should not run the stream + first_batch = Enum.at(batches, 0) + model = fit_first_n(first_batch, opts) + batches = Stream.drop(batches, 1) + + Enum.reduce( + batches, + model, + # TODO: JIT + fn batch, model -> partial_fit(model, batch) end + ) + end + + deftransformp validate_batch(batch, opts) do + {batch_size, num_features} = Nx.shape(batch) + num_components = opts[:num_components] + + cond do + num_components > batch_size -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + batch_size = #{batch_size}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + end + end + + defnp fit_first_n(x, opts) do # This is similar to Scholar.Decomposition.PCA.fit_n - {num_samples, _} = Nx.shape(x) + num_samples = Nx.u64(Nx.axis_size(x, 0)) num_components = opts[:num_components] mean = Nx.mean(x, axes: [0]) variance = Nx.variance(x, axes: [0]) @@ -157,84 +223,135 @@ defmodule Scholar.Decomposition.IncrementalPCA do mean: mean, variance: variance, explained_variance: explained_variance[[0..(num_components - 1)]], - explained_variance_ratio: explained_variance_ratio + explained_variance_ratio: explained_variance_ratio, + whiten?: opts[:whiten?] } end + @doc """ + Updates an Incremental PCA model on samples `x`. + """ + deftransform partial_fit(model, x) do + {num_samples, num_features} = Nx.shape(x) + num_components = model.num_components + + cond do + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + batch_size = #{num_samples}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + end + + partial_fit_n(model, x) + end + defnp partial_fit_n(model, x) do - num_samples_seen = model.num_samples_seen num_components = model.num_components components = model.components singular_values = model.singular_values + num_samples_seen = model.num_samples_seen mean = model.mean variance = model.variance {num_samples, _} = Nx.shape(x) - last_sample_count = Nx.broadcast(num_samples_seen, {Nx.axis_size(x, 1)}) - - {num_total_samples, col_mean, col_variance} = - incremental_mean_and_var(x, last_sample_count, mean, variance) - num_total_samples = num_total_samples[0] - # if num_samples_seen > 0 do - col_batch_mean = Nx.mean(x, axes: [0]) - x_centered = x - col_batch_mean + {x_centered, x_mean, new_num_samples_seen, new_mean, new_variance} = + incremental_mean_and_variance(x, num_samples_seen, mean, variance) mean_correction = - Nx.sqrt(num_samples_seen / num_total_samples) * num_samples * (mean - col_batch_mean) + Nx.sqrt(num_samples_seen / new_num_samples_seen) * num_samples * (mean - x_mean) mean_correction = Nx.new_axis(mean_correction, 0) - to_add = Nx.reshape(singular_values, {1, :auto}) * components - z = Nx.concatenate([Nx.transpose(to_add), x_centered, mean_correction], axis: 0) - {u, s, vt} = Nx.LinAlg.svd(z, full_matrices?: false) + + matrix = + Nx.concatenate( + [ + Nx.new_axis(singular_values, 1) * components, + x_centered, + mean_correction + ], + axis: 0 + ) + + {u, s, vt} = Nx.LinAlg.svd(matrix, full_matrices?: false) {_, vt} = 
Scholar.Decomposition.Utils.flip_svd(u, vt) - components = vt[[0..(num_components - 1)]] - singular_values = s[[0..(num_components - 1)]] - explained_variance = s * s / (num_total_samples - 1) + new_components = vt[[0..(num_components - 1)]] + new_singular_values = s[[0..(num_components - 1)]] + new_explained_variance = singular_values * singular_values / (new_num_samples_seen - 1) - explained_variance_ratio = - singular_values * singular_values / Nx.sum(col_variance * num_total_samples) + new_explained_variance_ratio = + singular_values * singular_values / Nx.sum(new_variance * new_num_samples_seen) %__MODULE__{ - num_samples_seen: num_total_samples, num_components: num_components, - components: components, - singular_values: singular_values, - mean: col_mean, - variance: col_variance, - explained_variance: explained_variance, - explained_variance_ratio: explained_variance_ratio + components: new_components, + singular_values: new_singular_values, + num_samples_seen: new_num_samples_seen, + mean: new_mean, + variance: new_variance, + explained_variance: new_explained_variance, + explained_variance_ratio: new_explained_variance_ratio, + whiten?: model.whiten? } end - defnp incremental_mean_and_var(x, last_sample_count, last_mean, last_variance) do - new_sample_count = Nx.axis_size(x, 0) - updated_sample_count = last_sample_count + new_sample_count - last_sum = last_sample_count * last_mean - new_sum = Nx.sum(x, axes: [0]) - updated_mean = (last_sum + new_sum) / updated_sample_count - t = new_sum / new_sample_count - temp = x - t - correction = Nx.sum(temp, axes: [0]) - temp = temp * temp + defnp incremental_mean_and_variance(x, num_samples_seen, mean, variance) do + num_samples = Nx.axis_size(x, 0) + new_num_samples_seen = num_samples_seen + num_samples + sum = num_samples_seen * mean + x_sum = Nx.sum(x, axes: [0]) + new_mean = (sum + x_sum) / new_num_samples_seen + x_mean = x_sum / num_samples + x_centered = x - x_mean + correction = Nx.sum(x_centered, axes: [0]) - new_unnormalized_variance = - Nx.sum(temp, axes: [0]) - correction * correction / new_sample_count + x_unnormalized_variance = + Nx.sum(x_centered * x_centered, axes: [0]) - correction * correction / num_samples - last_unnormalized_variance = last_sample_count * last_variance - last_over_new_count = last_sample_count / new_sample_count + unnormalized_variance = num_samples_seen * variance + last_over_new_count = num_samples_seen / num_samples - updated_unnormalized_variance = - last_unnormalized_variance + - new_unnormalized_variance + - last_over_new_count / updated_sample_count * - (last_sum / last_over_new_count - new_sum) ** 2 + new_unnormalized_variance = + unnormalized_variance + + x_unnormalized_variance + + last_over_new_count / new_num_samples_seen * + (sum / last_over_new_count - x_sum) ** 2 - zeros = last_sample_count == 0 + new_variance = new_unnormalized_variance / new_num_samples_seen + {x_centered, x_mean, new_num_samples_seen, new_mean, new_variance} + end - updated_unnormalized_variance = - Nx.select(zeros, new_unnormalized_variance, updated_unnormalized_variance) + @doc """ + Documentation goes here. + """ + deftransform transform(model, x) do + transform_n(model, x) + end - updated_variance = updated_unnormalized_variance / updated_sample_count - {updated_sample_count, updated_mean, updated_variance} + defnp transform_n( + %__MODULE__{ + components: components, + explained_variance: explained_variance, + mean: mean, + whiten?: whiten? 
+ } = _model, + x + ) do + # This is the same as Scholar.Decomposition.PCA.transform_n! + z = Nx.dot(x - mean, [1], components, [1]) + + if whiten? do + z / Nx.sqrt(explained_variance) + else + z + end end end diff --git a/lib/scholar/shared.ex b/lib/scholar/shared.ex index 59a90fe2..02690ef7 100644 --- a/lib/scholar/shared.ex +++ b/lib/scholar/shared.ex @@ -92,7 +92,8 @@ defmodule Scholar.Shared do defn get_batches(tensor, opts) do {size, dim} = Nx.shape(tensor) - batch_size = min(opts[:batch_size], size) + batch_size = min(size, opts[:batch_size]) + min_batch_size = if opts[:min_batch_size], do: opts[:min_batch_size], else: 0 num_batches = div(size, batch_size) leftover_size = rem(size, batch_size) @@ -102,7 +103,7 @@ defmodule Scholar.Shared do |> Nx.reshape({num_batches, batch_size, dim}) leftover = - if leftover_size > 0 do + if leftover_size > min_batch_size do Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) else nil diff --git a/test/scholar/decomposition/incremental_pca_test.exs b/test/scholar/decomposition/incremental_pca_test.exs new file mode 100644 index 00000000..8534d0a9 --- /dev/null +++ b/test/scholar/decomposition/incremental_pca_test.exs @@ -0,0 +1,5 @@ +defmodule Scholar.Decomposition.IncrementalPCATest do + use Scholar.Case, async: true + alias Scholar.Decomposition.IncrementalPCA + doctest IncrementalPCA +end From a8adff4fb0ba8204c3149f254bf8ad80a2fe2598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 18 Jul 2024 21:13:17 +0200 Subject: [PATCH 04/13] Fix PCA and its tests. --- lib/scholar/decomposition/pca.ex | 11 +++++++---- test/scholar/decomposition/pca_test.exs | 6 +++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index 24f223ef..699947cc 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -15,9 +15,10 @@ defmodule Scholar.Decomposition.PCA do import Nx.Defn @derive {Nx.Container, - keep: [:num_components, :num_samples, :num_features], + keep: [:num_components, :num_features], containers: [ :components, + :num_samples, :explained_variance, :explained_variance_ratio, :singular_values, @@ -28,6 +29,8 @@ defmodule Scholar.Decomposition.PCA do :explained_variance, :explained_variance_ratio, :singular_values, + :num_samples, + :num_features, :mean, :num_components, :num_features, @@ -124,7 +127,7 @@ defmodule Scholar.Decomposition.PCA do ), num_components: 2, num_features: 2, - num_samples: 6 + num_samples: Nx.u64(6) } """ deftransform fit(x, opts \\ []) do @@ -166,8 +169,8 @@ defmodule Scholar.Decomposition.PCA do singular_values: singular_values[[0..(num_components - 1)]], mean: mean, num_components: num_components, - num_features: num_features, - num_samples: num_samples + num_features: Nx.u64(num_features), + num_samples: Nx.u64(num_samples) } end diff --git a/test/scholar/decomposition/pca_test.exs b/test/scholar/decomposition/pca_test.exs index a5cec890..b4f43b14 100644 --- a/test/scholar/decomposition/pca_test.exs +++ b/test/scholar/decomposition/pca_test.exs @@ -40,7 +40,7 @@ defmodule Scholar.Decomposition.PCATest do assert_all_close(model.singular_values, Nx.tensor([6.30061232, 0.54980396]), atol: 1.0e-3) assert model.mean == Nx.tensor([0.0, 0.0]) assert model.num_components == 2 - assert model.num_samples == 6 + assert model.num_samples == Nx.u64(6) assert model.num_features == 2 end @@ -70,7 +70,7 @@ defmodule Scholar.Decomposition.PCATest do 
assert_all_close(model.singular_values, Nx.tensor([38.89730453491211]), atol: 1.0e-3) assert model.mean == Nx.tensor([28.5, 2.0, 3.5]) assert model.num_components == 1 - assert model.num_samples == 2 + assert model.num_samples == Nx.u64(2) assert model.num_features == 3 end @@ -121,7 +121,7 @@ defmodule Scholar.Decomposition.PCATest do assert_all_close(model.mean, Nx.tensor([3.83333333, 0.33333333, 1.66666667]), atol: 1.0e-2) assert model.num_components == 2 - assert model.num_samples == 6 + assert model.num_samples == Nx.u64(6) assert model.num_features == 3 assert_all_close( From d7dfde1265c75de7fde8e32ac5df05968c91721e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 18 Jul 2024 21:29:20 +0200 Subject: [PATCH 05/13] Bug fix. --- lib/scholar/decomposition/incremental_pca.ex | 2 -- lib/scholar/decomposition/pca.ex | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 920c7da3..5f8f23ec 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -118,7 +118,6 @@ defmodule Scholar.Decomposition.IncrementalPCA do nil end - # TODO: Check this! opts = Keyword.put(opts, :batch_size, batch_size) fit_n(x, opts) @@ -165,7 +164,6 @@ defmodule Scholar.Decomposition.IncrementalPCA do """ def fit(batches = %Stream{}, opts) do opts = NimbleOptions.validate!(opts, @stream_schema) - # This should not run the stream first_batch = Enum.at(batches, 0) model = fit_first_n(first_batch, opts) batches = Stream.drop(batches, 1) diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index 699947cc..59bb70e0 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -32,9 +32,7 @@ defmodule Scholar.Decomposition.PCA do :num_samples, :num_features, :mean, - :num_components, - :num_features, - :num_samples + :num_components ] fit_opts_schema = [ From 1a4ba4ae40e01c2faf4b2ede81a88a9d5911ac46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Fri, 19 Jul 2024 15:07:31 +0200 Subject: [PATCH 06/13] Single stream run --- lib/scholar/decomposition/incremental_pca.ex | 61 +++++++------------- 1 file changed, 21 insertions(+), 40 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 5f8f23ec..6f3835a0 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -25,9 +25,9 @@ defmodule Scholar.Decomposition.IncrementalPCA do ]} defstruct [ :num_components, - :num_samples_seen, :components, :singular_values, + :num_samples_seen, :mean, :variance, :explained_variance, @@ -62,8 +62,8 @@ defmodule Scholar.Decomposition.IncrementalPCA do ] ] - @stream_schema NimbleOptions.new!(stream_opts) @tensor_schema NimbleOptions.new!(tensor_opts) + @stream_schema NimbleOptions.new!(stream_opts) @doc """ Fits an Incremental PCA model. 
@@ -124,13 +124,15 @@ defmodule Scholar.Decomposition.IncrementalPCA do end defn fit_n(x, opts) do + num_components = opts[:num_components] batch_size = opts[:batch_size] {batches, leftover} = - get_batches(x, batch_size: batch_size, min_batch_size: opts[:num_components]) + get_batches(x, batch_size: batch_size, min_batch_size: num_components) num_batches = Nx.axis_size(batches, 0) - model = fit_first_n(batches[0], opts) + + model = fit_head_n(batches[0], opts) {model, _} = while { @@ -146,13 +148,10 @@ defmodule Scholar.Decomposition.IncrementalPCA do {model, {batches, i + 1}} end - model = - case leftover do - nil -> model - _ -> partial_fit_n(model, leftover) - end - - model + case leftover do + nil -> model + _ -> partial_fit_n(model, leftover) + end end @doc """ @@ -164,43 +163,21 @@ defmodule Scholar.Decomposition.IncrementalPCA do """ def fit(batches = %Stream{}, opts) do opts = NimbleOptions.validate!(opts, @stream_schema) - first_batch = Enum.at(batches, 0) - model = fit_first_n(first_batch, opts) - batches = Stream.drop(batches, 1) Enum.reduce( batches, - model, - # TODO: JIT - fn batch, model -> partial_fit(model, batch) end + nil, + fn batch, model -> fit_batch(model, batch, opts) end ) end - deftransformp validate_batch(batch, opts) do - {batch_size, num_features} = Nx.shape(batch) - num_components = opts[:num_components] - - cond do - num_components > batch_size -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - batch_size = #{batch_size}, got #{num_components} - """ + defp fit_batch(nil, batch, opts), do: fit_head_n(batch, opts) + defp fit_batch(%__MODULE__{} = model, batch, _opts), do: partial_fit(model, batch) - num_components > num_features -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - num_features = #{num_features}, got #{num_components} - """ - end - end - - defnp fit_first_n(x, opts) do + defnp fit_head_n(x, opts) do # This is similar to Scholar.Decomposition.PCA.fit_n - num_samples = Nx.u64(Nx.axis_size(x, 0)) num_components = opts[:num_components] + num_samples = Nx.u64(Nx.axis_size(x, 0)) mean = Nx.mean(x, axes: [0]) variance = Nx.variance(x, axes: [0]) x_centered = x - mean @@ -247,6 +224,9 @@ defmodule Scholar.Decomposition.IncrementalPCA do num_components must be less than or equal to \ num_features = #{num_features}, got #{num_components} """ + + true -> + nil end partial_fit_n(model, x) @@ -254,9 +234,10 @@ defmodule Scholar.Decomposition.IncrementalPCA do defnp partial_fit_n(model, x) do num_components = model.num_components + num_samples_seen = model.num_samples_seen + components = model.components singular_values = model.singular_values - num_samples_seen = model.num_samples_seen mean = model.mean variance = model.variance {num_samples, _} = Nx.shape(x) From 7e56f07c9cd9421cabe8873c186be17521fc4f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Fri, 19 Jul 2024 20:46:04 +0200 Subject: [PATCH 07/13] Remove fit clauses that works with tensor, add docs --- lib/scholar/decomposition/incremental_pca.ex | 185 ++++++++----------- lib/scholar/decomposition/pca.ex | 1 + lib/scholar/neighbors/brute_knn.ex | 22 +++ lib/scholar/shared.ex | 22 --- 4 files changed, 95 insertions(+), 135 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 6f3835a0..89c6af9c 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -2,7 +2,10 @@ defmodule 
Scholar.Decomposition.IncrementalPCA do @moduledoc """ Incremental Principal Component Analysis. - Description goes here (elaborate on incremental approach) + Performs linear dimensionality reduction by processing the input data + in batches and incrementally updating the principal components. + This iterative approach is particularly suitable for datasets too large to fit in memory, + as its memory complexity is independent of the number of data samples. References: @@ -13,7 +16,7 @@ defmodule Scholar.Decomposition.IncrementalPCA do require Nx @derive {Nx.Container, - keep: [:num_components, :whiten?], + keep: [:whiten?], containers: [ :components, :singular_values, @@ -24,7 +27,6 @@ defmodule Scholar.Decomposition.IncrementalPCA do :explained_variance_ratio ]} defstruct [ - :num_components, :components, :singular_values, :num_samples_seen, @@ -35,7 +37,7 @@ defmodule Scholar.Decomposition.IncrementalPCA do :whiten? ] - stream_opts = [ + opts = [ num_components: [ required: true, type: :pos_integer, @@ -53,116 +55,50 @@ defmodule Scholar.Decomposition.IncrementalPCA do ] ] - tensor_opts = - stream_opts ++ - [ - batch_size: [ - type: :pos_integer, - doc: "The number of samples in a batch." - ] - ] - - @tensor_schema NimbleOptions.new!(tensor_opts) - @stream_schema NimbleOptions.new!(stream_opts) + @opts_schema NimbleOptions.new!(opts) @doc """ - Fits an Incremental PCA model. + Fits an Incremental PCA model on a stream of batches. ## Options - #{NimbleOptions.docs(@tensor_schema)} - """ - deftransform fit(%Nx.Tensor{} = x, opts) do - if Nx.rank(x) != 2 do - raise ArgumentError, - "expected input tensor to have shape {num_samples, num_features}, - got tensor with shape: #{inspect(Nx.shape(x))}" - end + #{NimbleOptions.docs(@opts_schema)} - opts = NimbleOptions.validate!(opts, @tensor_schema) + ## Return values - {num_samples, num_features} = Nx.shape(x) + The function returns a struct with the following parameters: - batch_size = - if opts[:batch_size] do - opts[:batch_size] - else - 5 * num_features - end + * `:num_components` - The number of principal components. - num_components = opts[:num_components] + * `:components` - Principal axes in feature space, representing the directions of maximum variance in the data. + Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. + The components are sorted by decreasing `:explained_variance`. - cond do - num_components > batch_size -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - batch_size = #{batch_size}, got #{num_components} - """ + * `:singular_values` - The singular values corresponding to each of the selected components. + The singular values are equal to the 2-norms of the `:num_components` variables in the lower-dimensional space. - num_components > num_samples -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - num_samples = #{num_samples}, got #{num_components} - """ + * `:num_samples_seen` - The number of data samples processed. - num_components > num_features -> - raise ArgumentError, - """ - num_components must be less than or equal to \ - num_features = #{num_features}, got #{num_components} - """ + * `:mean` - Per-feature empirical mean. - true -> - nil - end + * `:variance` - Per-feature empirical variance. - opts = Keyword.put(opts, :batch_size, batch_size) + * `:explained_variance` - Variance explained by each of the selected components. 
- fit_n(x, opts) - end + * `:explained_variance_ratio` - Percentage of variance explained by each of the selected components. - defn fit_n(x, opts) do - num_components = opts[:num_components] - batch_size = opts[:batch_size] - - {batches, leftover} = - get_batches(x, batch_size: batch_size, min_batch_size: num_components) - - num_batches = Nx.axis_size(batches, 0) - - model = fit_head_n(batches[0], opts) - - {model, _} = - while { - model, - { - batches, - i = Nx.u64(1) - } - }, - i < num_batches do - batch = batches[i] - model = partial_fit_n(model, batch) - {model, {batches, i + 1}} - end - - case leftover do - nil -> model - _ -> partial_fit_n(model, leftover) - end - end - - @doc """ - Fits an Incremental PCA model on a stream of batches. + * `:whiten?` - Whether to apply whitening. - ## Options + ## Examples - #{NimbleOptions.docs(@stream_schema)} + iex> batches = Scidata.Iris.download() |> elem(0) |> Nx.tensor() |> Nx.to_batched(10) + iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) + iex> ipca.components + iex> ipca.singular_values """ def fit(batches = %Stream{}, opts) do - opts = NimbleOptions.validate!(opts, @stream_schema) + opts = NimbleOptions.validate!(opts, @opts_schema) + IO.puts("num_components: #{opts[:num_components]}") Enum.reduce( batches, @@ -174,6 +110,29 @@ defmodule Scholar.Decomposition.IncrementalPCA do defp fit_batch(nil, batch, opts), do: fit_head_n(batch, opts) defp fit_batch(%__MODULE__{} = model, batch, _opts), do: partial_fit(model, batch) + deftransformp fit_head(x, opts) do + {num_samples, num_features} = Nx.shape(x) + num_components = opts[:num_componenets] + + cond do + num_components > num_samples -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + batch_size = #{num_samples}, got #{num_components} + """ + + num_components > num_features -> + raise ArgumentError, + """ + num_components must be less than or equal to \ + num_features = #{num_features}, got #{num_components} + """ + end + + fit_head_n(x, opts) + end + defnp fit_head_n(x, opts) do # This is similar to Scholar.Decomposition.PCA.fit_n num_components = opts[:num_components] @@ -191,10 +150,9 @@ defmodule Scholar.Decomposition.IncrementalPCA do (explained_variance / Nx.sum(explained_variance))[[0..(num_components - 1)]] %__MODULE__{ - num_samples_seen: num_samples, - num_components: num_components, components: components, singular_values: singular_values, + num_samples_seen: num_samples, mean: mean, variance: variance, explained_variance: explained_variance[[0..(num_components - 1)]], @@ -203,26 +161,24 @@ defmodule Scholar.Decomposition.IncrementalPCA do } end - @doc """ - Updates an Incremental PCA model on samples `x`. 
- """ + @doc false deftransform partial_fit(model, x) do + {num_components, num_features_seen} = Nx.shape(model.components) {num_samples, num_features} = Nx.shape(x) - num_components = model.num_components cond do - num_components > num_samples -> + num_features_seen != num_features -> raise ArgumentError, """ - num_components must be less than or equal to \ - batch_size = #{num_samples}, got #{num_components} + each batch must have the same number of features, \ + got #{num_features_seen} and #{num_features} """ - num_components > num_features -> + num_components > num_samples -> raise ArgumentError, """ num_components must be less than or equal to \ - num_features = #{num_features}, got #{num_components} + batch_size = #{num_samples}, got #{num_components} """ true -> @@ -233,16 +189,15 @@ defmodule Scholar.Decomposition.IncrementalPCA do end defnp partial_fit_n(model, x) do - num_components = model.num_components - num_samples_seen = model.num_samples_seen - components = model.components + num_components = Nx.axis_size(components, 0) singular_values = model.singular_values + num_samples_seen = model.num_samples_seen mean = model.mean variance = model.variance {num_samples, _} = Nx.shape(x) - {x_centered, x_mean, new_num_samples_seen, new_mean, new_variance} = + {x_mean, x_centered, new_num_samples_seen, new_mean, new_variance} = incremental_mean_and_variance(x, num_samples_seen, mean, variance) mean_correction = @@ -270,7 +225,6 @@ defmodule Scholar.Decomposition.IncrementalPCA do singular_values * singular_values / Nx.sum(new_variance * new_num_samples_seen) %__MODULE__{ - num_components: num_components, components: new_components, singular_values: new_singular_values, num_samples_seen: new_num_samples_seen, @@ -305,11 +259,16 @@ defmodule Scholar.Decomposition.IncrementalPCA do (sum / last_over_new_count - x_sum) ** 2 new_variance = new_unnormalized_variance / new_num_samples_seen - {x_centered, x_mean, new_num_samples_seen, new_mean, new_variance} + {x_mean, x_centered, new_num_samples_seen, new_mean, new_variance} end @doc """ - Documentation goes here. + Applies dimensionality reduction to the data `x` using Incremental PCA `model`. + + ## Examples + + iex> batches = Scidata.Iris.download() |> elem(0) |> Nx.tensor() |> Nx.to_batched(10) + iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) """ deftransform transform(model, x) do transform_n(model, x) @@ -324,7 +283,7 @@ defmodule Scholar.Decomposition.IncrementalPCA do } = _model, x ) do - # This is the same as Scholar.Decomposition.PCA.transform_n! + # This is literally the same as Scholar.Decomposition.PCA.transform_n! z = Nx.dot(x - mean, [1], components, [1]) if whiten? do diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index 59bb70e0..dae79826 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -102,6 +102,7 @@ defmodule Scholar.Decomposition.PCA do * `:num_samples` - Number of samples in the training data. 
## Examples + iex> x = Nx.tensor([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) iex> Scholar.Decomposition.PCA.fit(x) %Scholar.Decomposition.PCA{ diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index c0176117..6c4ce5eb 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -215,6 +215,28 @@ defmodule Scholar.Neighbors.BruteKNN do {neighbor_indices, neighbor_distances} end + defn get_batches(tensor, opts) do + {size, dim} = Nx.shape(tensor) + batch_size = min(size, opts[:batch_size]) + min_batch_size = if opts[:min_batch_size], do: opts[:min_batch_size], else: 0 + num_batches = div(size, batch_size) + leftover_size = rem(size, batch_size) + + batches = + tensor + |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) + |> Nx.reshape({num_batches, batch_size, dim}) + + leftover = + if leftover_size > min_batch_size do + Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) + else + nil + end + + {batches, leftover} + end + defnp brute_force_search(data, query, opts) do k = opts[:num_neighbors] metric = opts[:metric] diff --git a/lib/scholar/shared.ex b/lib/scholar/shared.ex index 02690ef7..61330283 100644 --- a/lib/scholar/shared.ex +++ b/lib/scholar/shared.ex @@ -89,26 +89,4 @@ defmodule Scholar.Shared do valid_broadcast(to_parse - 1, n_dims, shape1, shape2) end - - defn get_batches(tensor, opts) do - {size, dim} = Nx.shape(tensor) - batch_size = min(size, opts[:batch_size]) - min_batch_size = if opts[:min_batch_size], do: opts[:min_batch_size], else: 0 - num_batches = div(size, batch_size) - leftover_size = rem(size, batch_size) - - batches = - tensor - |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) - |> Nx.reshape({num_batches, batch_size, dim}) - - leftover = - if leftover_size > min_batch_size do - Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) - else - nil - end - - {batches, leftover} - end end From b773a948788954f2d5ce07aca357419a7782c5b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Fri, 19 Jul 2024 20:51:28 +0200 Subject: [PATCH 08/13] Remove unused import. 
--- lib/scholar/decomposition/incremental_pca.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 89c6af9c..f030d017 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -12,7 +12,6 @@ defmodule Scholar.Decomposition.IncrementalPCA do * [1] [Incremental Learning for Robust Visual Tracking](https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf) """ import Nx.Defn - import Scholar.Shared require Nx @derive {Nx.Container, From a50d30ca6196a558e0c755dd0ac5986b34762915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Mon, 22 Jul 2024 10:31:10 +0200 Subject: [PATCH 09/13] Update docstrings and add scidata to deps --- lib/scholar/decomposition/incremental_pca.ex | 31 +++++++++++++++++--- mix.exs | 3 +- mix.lock | 22 ++++++++------ 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index f030d017..7a913250 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -90,14 +90,22 @@ defmodule Scholar.Decomposition.IncrementalPCA do ## Examples - iex> batches = Scidata.Iris.download() |> elem(0) |> Nx.tensor() |> Nx.to_batched(10) + iex> {x, _} = Scidata.Iris.download() + iex> batches = x |> Nx.tensor() |> Nx.to_batched(10) iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) iex> ipca.components + Nx.tensor( + f64[2][4] + [ + [-0.333540033447479, 0.10489666948487557, -0.8618107080105579, -0.367464336197646], + [-0.586203017375807, -0.7916955422591979, 0.158744098990766, -0.06621559023520115] + ] + ) iex> ipca.singular_values + Nx.tensor([77.05782028025969, 10.137862896272168]) """ def fit(batches = %Stream{}, opts) do opts = NimbleOptions.validate!(opts, @opts_schema) - IO.puts("num_components: #{opts[:num_components]}") Enum.reduce( batches, @@ -266,8 +274,23 @@ defmodule Scholar.Decomposition.IncrementalPCA do ## Examples - iex> batches = Scidata.Iris.download() |> elem(0) |> Nx.tensor() |> Nx.to_batched(10) - iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) + iex> {x, _} = Scidata.Iris.download() + iex> batches = x |> Nx.tensor() |> Nx.to_batched(10) + iex> x = Nx.tensor( + [ + [5.2, 2.6, 2.475, 0.7], + [6.1, 3.2, 3.95, 1.3], + [7.0, 3.8, 5.425, 1.9] + ] + ) + iex> Scholar.Decomposition.IncrementalPCA.transform(ipca, x) + Nx.tensor( + [ + [1.4564743682550334, 0.5657988895852432], + [-0.2724231831356622, -0.24238310361929516], + [-2.001320781438254, -1.050564912015664] + ] + ) """ deftransform transform(model, x) do transform_n(model, x) diff --git a/mix.exs b/mix.exs index b1d53755..439a86cd 100644 --- a/mix.exs +++ b/mix.exs @@ -33,7 +33,8 @@ defmodule Scholar.MixProject do {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, ">= 0.0.0", only: :test}, {:polaris, "~> 0.1"}, - {:benchee, "~> 1.0", only: :dev} + {:benchee, "~> 1.0", only: :dev}, + {:scidata, "~> 0.1.11", only: :test} ] end diff --git a/mix.lock b/mix.lock index a16434fd..069f9d54 100644 --- a/mix.lock +++ b/mix.lock @@ -1,19 +1,23 @@ %{ - "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", 
"7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, + "benchee": {:hex, :benchee, "1.3.1", "c786e6a76321121a44229dde3988fc772bca73ea75170a73fd5f4ddf1af95ccf", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "76224c58ea1d0391c8309a8ecbfe27d71062878f59bd41a390266bf4ac1cc56d"}, + "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, - "ex_doc": {:hex, :ex_doc, "0.34.0", "ab95e0775db3df71d30cf8d78728dd9261c355c81382bcd4cefdc74610bef13e", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "60734fb4c1353f270c3286df4a0d51e65a2c1d9fba66af3940847cc65a8066d7"}, - "exla": {:hex, :exla, "0.7.1", "790493288cf4441abed98df0c4e98da15a2e3a7fa27cd2a1f74ec0693952c579", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "ec9c1698a9a17b859d79f9b3c1d75c370335580cdd0353db9c2017f86155e2ec"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, + "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, + "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: 
:earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, + "exla": {:hex, :exla, "0.7.3", "51310270a0976974fc758f7b28ebd6ca8e099b3d6fc78b0d484c808e977cb914", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5b3d5741a24aada21d3b0feb4b99d1fc3c8457f995a63ea16684d8d5678b96ff"}, + "jason": {:hex, :jason, "1.4.3", "d3f984eeb96fe53b85d20e0b049f03e57d075b5acda3ac8d465c969a2536c17b", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "9a90e868927f7c777689baa16d86f4d0e086d968db5c05d917ccff6d443e58a3"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, - "makeup_erlang": {:hex, :makeup_erlang, "1.0.0", "6f0eff9c9c489f26b69b61440bf1b238d95badae49adac77973cbacae87e3c2e", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "ea7a9307de9d1548d2a72d299058d1fd2339e3d398560a0e46c27dab4891e4d2"}, - "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, + "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, + "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, - "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", 
"80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, - "nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"}, + "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, + "nx": {:hex, :nx, "0.7.3", "51ff45d9f9ff58b616f4221fa54ccddda98f30319bb8caaf86695234a469017a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "5ff29af84f08db9bda66b8ef7ce92ab583ab4f983629fe00b479f1e5c7c705a6"}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"}, From 0ef3a3b2472e4bcb8b89ec7639f39e6f95d1889e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Mon, 22 Jul 2024 11:40:24 +0200 Subject: [PATCH 10/13] Fix get_batches and docstrings --- lib/scholar/decomposition/incremental_pca.ex | 22 +++++++------------- lib/scholar/decomposition/pca.ex | 2 +- lib/scholar/neighbors/brute_knn.ex | 5 ++++- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 7a913250..667fc0d9 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -95,14 +95,13 @@ defmodule Scholar.Decomposition.IncrementalPCA do iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) iex> ipca.components Nx.tensor( - f64[2][4] [ - [-0.333540033447479, 0.10489666948487557, -0.8618107080105579, -0.367464336197646], - [-0.586203017375807, -0.7916955422591979, 0.158744098990766, -0.06621559023520115] + [-0.33354005217552185, 0.1048964187502861, -0.8618107080105579, -0.3674643635749817], + [-0.5862125754356384, -0.7916879057884216, 0.15874788165092468, 
-0.06621300429105759] ] ) iex> ipca.singular_values - Nx.tensor([77.05782028025969, 10.137862896272168]) + Nx.tensor([77.05782028025969, 10.137848854064941]) """ def fit(batches = %Stream{}, opts) do opts = NimbleOptions.validate!(opts, @opts_schema) @@ -276,19 +275,14 @@ defmodule Scholar.Decomposition.IncrementalPCA do iex> {x, _} = Scidata.Iris.download() iex> batches = x |> Nx.tensor() |> Nx.to_batched(10) - iex> x = Nx.tensor( - [ - [5.2, 2.6, 2.475, 0.7], - [6.1, 3.2, 3.95, 1.3], - [7.0, 3.8, 5.425, 1.9] - ] - ) + iex> ipca = Scholar.Decomposition.IncrementalPCA.fit(batches, num_components: 2) + iex> x = Nx.tensor([[5.2, 2.6, 2.475, 0.7], [6.1, 3.2, 3.95, 1.3], [7.0, 3.8, 5.425, 1.9]]) iex> Scholar.Decomposition.IncrementalPCA.transform(ipca, x) Nx.tensor( [ - [1.4564743682550334, 0.5657988895852432], - [-0.2724231831356622, -0.24238310361929516], - [-2.001320781438254, -1.050564912015664] + [1.4564743041992188, 0.5657951235771179], + [-0.27242332696914673, -0.24238374829292297], + [-2.0013210773468018, -1.0505625009536743] ] ) """ diff --git a/lib/scholar/decomposition/pca.ex b/lib/scholar/decomposition/pca.ex index dae79826..9f78c765 100644 --- a/lib/scholar/decomposition/pca.ex +++ b/lib/scholar/decomposition/pca.ex @@ -168,7 +168,7 @@ defmodule Scholar.Decomposition.PCA do singular_values: singular_values[[0..(num_components - 1)]], mean: mean, num_components: num_components, - num_features: Nx.u64(num_features), + num_features: num_features, num_samples: Nx.u64(num_samples) } end diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 6c4ce5eb..9c1f2625 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -218,7 +218,10 @@ defmodule Scholar.Neighbors.BruteKNN do defn get_batches(tensor, opts) do {size, dim} = Nx.shape(tensor) batch_size = min(size, opts[:batch_size]) - min_batch_size = if opts[:min_batch_size], do: opts[:min_batch_size], else: 0 + min_batch_size = case opts[:min_batch_size] do + nil -> 0 + b -> b + end num_batches = div(size, batch_size) leftover_size = rem(size, batch_size) From c59d8b04903df0a2c9add8c0f672bd2781417b9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Mon, 22 Jul 2024 11:42:40 +0200 Subject: [PATCH 11/13] mix format --- lib/scholar/neighbors/brute_knn.ex | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 9c1f2625..707c8d67 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -218,10 +218,13 @@ defmodule Scholar.Neighbors.BruteKNN do defn get_batches(tensor, opts) do {size, dim} = Nx.shape(tensor) batch_size = min(size, opts[:batch_size]) - min_batch_size = case opts[:min_batch_size] do - nil -> 0 - b -> b - end + + min_batch_size = + case opts[:min_batch_size] do + nil -> 0 + b -> b + end + num_batches = div(size, batch_size) leftover_size = rem(size, batch_size) From 3b5bc3399d0cdf1278f99de960525210e0148b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 25 Jul 2024 13:34:27 +0200 Subject: [PATCH 12/13] Fix typo: componenets -> components --- lib/scholar/decomposition/incremental_pca.ex | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 667fc0d9..35cdb860 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++
b/lib/scholar/decomposition/incremental_pca.ex @@ -113,12 +113,12 @@ defmodule Scholar.Decomposition.IncrementalPCA do ) end - defp fit_batch(nil, batch, opts), do: fit_head_n(batch, opts) + defp fit_batch(nil, batch, opts), do: fit_head(batch, opts) defp fit_batch(%__MODULE__{} = model, batch, _opts), do: partial_fit(model, batch) deftransformp fit_head(x, opts) do {num_samples, num_features} = Nx.shape(x) - num_components = opts[:num_componenets] + num_components = opts[:num_components] cond do num_components > num_samples -> @@ -134,6 +134,8 @@ defmodule Scholar.Decomposition.IncrementalPCA do num_components must be less than or equal to \ num_features = #{num_features}, got #{num_components} """ + + true -> nil end fit_head_n(x, opts) From 70a3c78d3c58f198e21a0686db299b5ad0e35673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Thu, 25 Jul 2024 13:35:47 +0200 Subject: [PATCH 13/13] mix format --- lib/scholar/decomposition/incremental_pca.ex | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/scholar/decomposition/incremental_pca.ex b/lib/scholar/decomposition/incremental_pca.ex index 35cdb860..41dbaf80 100644 --- a/lib/scholar/decomposition/incremental_pca.ex +++ b/lib/scholar/decomposition/incremental_pca.ex @@ -135,7 +135,8 @@ defmodule Scholar.Decomposition.IncrementalPCA do num_features = #{num_features}, got #{num_components} """ - true -> nil + true -> + nil end fit_head_n(x, opts)
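
The batching rule that patch 10 fixes and patch 11 reformats is easy to check in isolation. The sketch below is not part of the patch series: it redoes the split with plain Nx calls on a made-up 7x3 tensor with batch_size 2, to show how get_batches/2 divides a tensor into div(size, batch_size) full batches plus a rem(size, batch_size)-row leftover.

    # Illustrative sketch only; shapes and batch size are assumptions,
    # mirroring the div/rem split visible in get_batches/2 above.
    x = Nx.iota({7, 3})
    batch_size = 2
    {size, dim} = Nx.shape(x)
    num_batches = div(size, batch_size)    # 3 full batches
    leftover_size = rem(size, batch_size)  # 1 row left over

    # The first num_batches * batch_size rows stack into full batches...
    batches =
      x[0..(num_batches * batch_size - 1)]
      |> Nx.reshape({num_batches, batch_size, dim})

    # ...and the remaining leftover_size rows form the leftover tensor.
    leftover = x[(num_batches * batch_size)..(size - 1)]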
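
Patch 12's second hunk, the added true -> nil clause, is load-bearing rather than cosmetic: an Elixir cond raises CondClauseError when no branch matches, so without a catch-all the validation in fit_head would crash on perfectly valid inputs. A minimal sketch of that validation shape, with made-up values standing in for real arguments:

    # Illustrative sketch; the values are assumptions, not real inputs.
    num_components = 2
    num_samples = 150
    num_features = 4

    cond do
      num_components > num_samples ->
        raise ArgumentError, "num_components must be less than or equal to num_samples"

      num_components > num_features ->
        raise ArgumentError, "num_components must be less than or equal to num_features"

      # Without this catch-all, valid inputs (like the ones above) would
      # raise CondClauseError instead of falling through to fit_head_n/2.
      true ->
        nil
    end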