diff --git a/lib/scholar/covariance/ledoit_wolf.ex b/lib/scholar/covariance/ledoit_wolf.ex index 6aa567b3..a9debc24 100644 --- a/lib/scholar/covariance/ledoit_wolf.ex +++ b/lib/scholar/covariance/ledoit_wolf.ex @@ -13,7 +13,7 @@ defmodule Scholar.Covariance.LedoitWolf do defstruct [:covariance, :shrinkage, :location] opts_schema = [ - assume_centered: [ + assume_centered?: [ default: false, type: :boolean, doc: """ @@ -93,7 +93,7 @@ defmodule Scholar.Covariance.LedoitWolf do iex> key = Nx.Random.key(0) iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32) - iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered: true) + iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true) iex> cov.covariance #Nx.Tensor< f32[3][3] @@ -110,7 +110,7 @@ defmodule Scholar.Covariance.LedoitWolf do end defnp fit_n(x, opts) do - {x, location} = center(x, opts) + {x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?]) {covariance, shrinkage} = ledoit_wolf(x) @@ -122,23 +122,6 @@ defmodule Scholar.Covariance.LedoitWolf do } end - defnp center(x, opts) do - x = - case Nx.shape(x) do - {_} -> Nx.new_axis(x, 1) - _ -> x - end - - location = - if opts[:assume_centered] do - 0 - else - Nx.mean(x, axes: [0]) - end - - {x - location, location} - end - defnp ledoit_wolf(x) do case Nx.shape(x) do {_n, 1} -> @@ -149,23 +132,6 @@ defmodule Scholar.Covariance.LedoitWolf do end end - defnp empirical_covariance(x) do - n = Nx.axis_size(x, 0) - - covariance = Nx.dot(x, [0], x, [0]) / n - - case Nx.shape(covariance) do - {} -> Nx.reshape(covariance, {1, 1}) - _ -> covariance - end - end - - defnp trace(x) do - x - |> Nx.take_diagonal() - |> Nx.sum() - end - defnp ledoit_wolf_shrinkage(x) do case Nx.shape(x) do {_, 1} -> @@ -182,9 +148,9 @@ defmodule Scholar.Covariance.LedoitWolf do defnp ledoit_wolf_shrinkage_complex(x) do {num_samples, num_features} = Nx.shape(x) - emp_cov = empirical_covariance(x) + emp_cov = Scholar.Covariance.Utils.empirical_covariance(x) - emp_cov_trace = trace(emp_cov) + emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov) mu = Nx.sum(emp_cov_trace) / num_features flatten_delta = Nx.flatten(emp_cov) diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex new file mode 100644 index 00000000..0302a350 --- /dev/null +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -0,0 +1,119 @@ +defmodule Scholar.Covariance.ShrunkCovariance do + @moduledoc """ + Covariance estimator with shrinkage. + """ + import Nx.Defn + + @derive {Nx.Container, containers: [:covariance, :location]} + defstruct [:covariance, :location] + + opts_schema = [ + assume_centered?: [ + default: false, + type: :boolean, + doc: """ + If `true`, data will not be centered before computation. + Useful when working with data whose mean is almost, but not exactly + zero. + If `false`, data will be centered before computation. + """ + ], + shrinkage: [ + default: 0.1, + type: :float, + doc: "Coefficient in the convex combination used for the computation + of the shrunk estimate. Range is [0, 1]." + ] + ] + + @opts_schema NimbleOptions.new!(opts_schema) + @doc """ + Fit the shrunk covariance model to `x`. + + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + The function returns a struct with the following parameters: + + * `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix. + * `:location` - Tensor of shape `{num_features,}`. + Estimated location, i.e. the estimated mean. + + ## Examples + + iex> key = Nx.Random.key(0) + iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32) + iex> model = Scholar.Covariance.ShrunkCovariance.fit(x) + iex> model.covariance + #Nx.Tensor< + f32[2][2] + [ + [0.7721845507621765, 0.19141492247581482], + [0.19141492247581482, 0.33952537178993225] + ] + > + iex> model.location + #Nx.Tensor< + f32[2] + [0.18202415108680725, -0.09216632694005966] + > + iex> key = Nx.Random.key(0) + iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32) + iex> model = Scholar.Covariance.ShrunkCovariance.fit(x, shrinkage: 0.4) + iex> model.covariance + #Nx.Tensor< + f32[2][2] + [ + [0.7000747323036194, 0.1276099532842636], + [0.1276099532842636, 0.41163527965545654] + ] + > + iex> model.location + #Nx.Tensor< + f32[2] + [0.18202415108680725, -0.09216632694005966] + > + """ + + deftransform fit(x, opts \\ []) do + fit_n(x, NimbleOptions.validate!(opts, @opts_schema)) + end + + defnp fit_n(x, opts) do + shrinkage = opts[:shrinkage] + + if shrinkage < 0 or shrinkage > 1 do + raise ArgumentError, + """ + expected :shrinkage option to be in [0, 1] range, \ + got shrinkage: #{inspect(Nx.shape(x))}\ + """ + end + + {x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?]) + + covariance = + Scholar.Covariance.Utils.empirical_covariance(x) + |> shrunk_covariance(shrinkage) + + %__MODULE__{ + covariance: covariance, + location: location + } + end + + defnp shrunk_covariance(emp_cov, shrinkage) do + num_features = Nx.axis_size(emp_cov, 1) + shrunk_cov = (1.0 - shrinkage) * emp_cov + emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov) + mu = Nx.sum(emp_cov_trace) / num_features + + mask = Nx.iota(Nx.shape(shrunk_cov)) + selector = Nx.remainder(mask, num_features + 1) == 0 + + shrunk_cov + shrinkage * mu * selector + end +end diff --git a/lib/scholar/covariance/utils.ex b/lib/scholar/covariance/utils.ex new file mode 100644 index 00000000..a6bff021 --- /dev/null +++ b/lib/scholar/covariance/utils.ex @@ -0,0 +1,39 @@ +defmodule Scholar.Covariance.Utils do + @moduledoc false + import Nx.Defn + require Nx + + defn center(x, assume_centered? \\ false) do + x = + case Nx.shape(x) do + {_} -> Nx.new_axis(x, 1) + _ -> x + end + + location = + if assume_centered? do + 0 + else + Nx.mean(x, axes: [0]) + end + + {x - location, location} + end + + defn empirical_covariance(x) do + n = Nx.axis_size(x, 0) + + covariance = Nx.dot(x, [0], x, [0]) / n + + case Nx.shape(covariance) do + {} -> Nx.reshape(covariance, {1, 1}) + _ -> covariance + end + end + + defn trace(x) do + x + |> Nx.take_diagonal() + |> Nx.sum() + end +end diff --git a/test/scholar/covariance/ledoit_wolf_test.exs b/test/scholar/covariance/ledoit_wolf_test.exs index 27e1893c..d7906b07 100644 --- a/test/scholar/covariance/ledoit_wolf_test.exs +++ b/test/scholar/covariance/ledoit_wolf_test.exs @@ -40,7 +40,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do ) end - test "fit test - :assume_centered is true" do + test "fit test - :assume_centered? is true" do key = key() {x, _new_key} = @@ -52,7 +52,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do type: :f32 ) - model = LedoitWolf.fit(x, assume_centered: true) + model = LedoitWolf.fit(x, assume_centered?: true) assert_all_close( model.covariance, diff --git a/test/scholar/covariance/shrunk_covariance_test.exs b/test/scholar/covariance/shrunk_covariance_test.exs new file mode 100644 index 00000000..e11bc676 --- /dev/null +++ b/test/scholar/covariance/shrunk_covariance_test.exs @@ -0,0 +1,150 @@ +defmodule Scholar.Covariance.ShrunkCovarianceTest do + use Scholar.Case, async: true + alias Scholar.Covariance.ShrunkCovariance + doctest ShrunkCovariance + + defp key do + Nx.Random.key(1) + end + + test "fit test - all default options" do + key = key() + + {x, _new_key} = + Nx.Random.multivariate_normal( + key, + Nx.tensor([0.0, 0.0, 0.0]), + Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), + shape: {10}, + type: :f32 + ) + + model = ShrunkCovariance.fit(x) + + assert_all_close( + model.covariance, + Nx.tensor([ + [2.0949244499206543, -0.13400490581989288, 0.5413897037506104], + [-0.13400490581989288, 1.2940725088119507, 0.0621684193611145], + [0.5413897037506104, 0.0621684193611145, 0.9303621053695679] + ]), + atol: 1.0e-3 + ) + + assert_all_close( + model.location, + Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]), + atol: 1.0e-3 + ) + end + + test "fit test - :assume_centered? is true" do + key = key() + + {x, _new_key} = + Nx.Random.multivariate_normal( + key, + Nx.tensor([0.0, 0.0, 0.0]), + Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), + shape: {10}, + type: :f32 + ) + + model = ShrunkCovariance.fit(x, assume_centered?: true) + + assert_all_close( + model.covariance, + Nx.tensor([ + [3.0643274784088135, 0.27685147523880005, 0.4822050631046295], + [0.27685147523880005, 1.5171942710876465, 0.03596973791718483], + [0.4822050631046295, 0.03596973791718483, 0.975387692451477] + ]), + atol: 1.0e-3 + ) + + assert_all_close(model.location, Nx.tensor(0), atol: 1.0e-3) + end + + test "fit test - :shrinkage" do + key = key() + + {x, _new_key} = + Nx.Random.multivariate_normal( + key, + Nx.tensor([0.0, 0.0, 0.0]), + Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), + shape: {10}, + type: :f32 + ) + + model = ShrunkCovariance.fit(x, shrinkage: 0.8) + + assert_all_close( + model.covariance, + Nx.tensor([ + [1.5853726863861084, -0.029778867959976196, 0.12030883133411407], + [-0.029778867959976196, 1.4074056148529053, 0.013815204612910748], + [0.12030883133411407, 0.013815204612910748, 1.3265810012817383] + ]), + atol: 1.0e-3 + ) + + assert_all_close( + model.location, + Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]), + atol: 1.0e-3 + ) + end + + test "fit test 2" do + key = key() + + {x, _new_key} = + Nx.Random.multivariate_normal( + key, + Nx.tensor([0.0, 0.0]), + Nx.tensor([[2.2, 1.5], [0.7, 1.1]]), + shape: {50}, + type: :f32 + ) + + model = ShrunkCovariance.fit(x) + + assert_all_close( + model.covariance, + Nx.tensor([ + [1.9810796976089478, 0.3997809886932373], + [0.3997809886932373, 1.0836023092269897] + ]), + atol: 1.0e-3 + ) + + assert_all_close(model.location, Nx.tensor([0.06882287561893463, 0.13750331103801727]), + atol: 1.0e-3 + ) + end + + test "fit test - 1 dim x" do + key = key() + + {x, _new_key} = + Nx.Random.multivariate_normal(key, Nx.tensor([0.0]), Nx.tensor([[0.4]]), + shape: {15}, + type: :f32 + ) + + x = Nx.flatten(x) + + model = ShrunkCovariance.fit(x) + + assert_all_close( + model.covariance, + Nx.tensor([ + [0.5322133302688599] + ]), + atol: 1.0e-3 + ) + + assert_all_close(model.location, Nx.tensor([0.060818854719400406]), atol: 1.0e-3) + end +end