Skip to content

Commit

Permalink
Fix TruncatedSVD (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn authored Oct 28, 2024
1 parent 2a601cc commit 6c02f4f
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions lib/scholar/decomposition/truncated_svd.ex
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ defmodule Scholar.Decomposition.TruncatedSVD do
f32[2]
[7.528080940246582, 0.7601959705352783]
>
"""

deftransform fit(x, opts \\ []) do
Expand Down Expand Up @@ -124,6 +125,19 @@ defmodule Scholar.Decomposition.TruncatedSVD do
[6.017930030822754, -0.18578583002090454]
]
>
iex> key = Nx.Random.key(0)
iex> x = Nx.tensor([[0, 0, 3], [1, 0, 3], [1, 1, 3], [3, 3, 3], [4, 4.5, 3]])
iex> tsvd = Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)
#Nx.Tensor<
f32[5][2]
[
[1.9478826522827148, 2.260593891143799],
[2.481153964996338, 1.906071662902832],
[3.023407220840454, 1.352442979812622],
[5.174456596374512, -0.46385863423347473],
[6.521108150482178, -1.6488237380981445]
]
>
"""

deftransform fit_transform(x, opts \\ []) do
Expand Down Expand Up @@ -162,7 +176,7 @@ defmodule Scholar.Decomposition.TruncatedSVD do
{u, sigma, vt} = randomized_svd(x, key, opts)
{_u, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)

x_transformed = Nx.dot(x, [1], vt, [0])
x_transformed = Nx.dot(x, Nx.transpose(vt))
explained_variance = Nx.variance(x_transformed, axes: [0])
full_variance = Nx.variance(x, axes: [0]) |> Nx.sum()
explained_variance_ratio = explained_variance / full_variance
Expand All @@ -177,7 +191,7 @@ defmodule Scholar.Decomposition.TruncatedSVD do

defnp fit_transform_n(x, key, opts) do
module = fit_n(x, key, opts)
Nx.dot(x, [1], module.components, [0])
Nx.dot(x, Nx.transpose(module.components))
end

defnp randomized_svd(m, key, opts) do
Expand All @@ -198,7 +212,8 @@ defmodule Scholar.Decomposition.TruncatedSVD do

q = randomized_range_finder(m, key, size: n_random, num_iter: num_iter)

b = Nx.dot(q, [-2], m, [-2])
q_t = Nx.transpose(q)
b = Nx.dot(q_t, m)
{uhat, s, vt} = Nx.LinAlg.svd(b)
u = Nx.dot(q, uhat)
vt = Nx.slice(vt, [0, 0], [num_components, num_features])
Expand Down

0 comments on commit 6c02f4f

Please sign in to comment.