Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added TruncatedSVD module #302

Merged
merged 11 commits into from
Oct 4, 2024
Merged

Added TruncatedSVD module #302

merged 11 commits into from
Oct 4, 2024

Conversation

norm4nn
Copy link
Contributor

@norm4nn norm4nn commented Sep 28, 2024

I've added the TruncatedSVD transformer. Currently, the only implemented option for the algorithm parameter is randomized. The arpack algorithm is yet to be implemented and it can be added in the future. I believe this module will be a valuable addition, as the PaCMAP algorithm, which is referenced in issue #238, utilizes this method.

For reference, here's the Scikit-learn implementation: Scikit-learn Truncated SVD implementation.

Copy link
Contributor

@msluszniak msluszniak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really good, just small comments


defnp fit_transform_n(x, opts) do
module = fit_n(x, opts)
Nx.dot(x, Nx.transpose(module.components))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use here Nx.dot/4 which automatically transposes particular tensors, depending on axes you provided

Comment on lines 215 to 220
{q, _a, _a_t, _i, _n_iter} =
while {q, a, a_t, i = Nx.tensor(1), n_iter}, Nx.less(i, n_iter) do
{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
{q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
{q, a, a_t, i + 1, n_iter}
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
{q, _a, _a_t, _i, _n_iter} =
while {q, a, a_t, i = Nx.tensor(1), n_iter}, Nx.less(i, n_iter) do
{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
{q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
{q, a, a_t, i + 1, n_iter}
end
{q, _} =
while {q, {a, a_t, i = Nx.tensor(1), n_iter}}, Nx.less(i, n_iter) do
{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
{q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
{q, {a, a_t, i + 1, n_iter}}
end

{u, sigma, vt} = randomized_svd(x, opts)
{_u, vt} = Scholar.Decomposition.PCA.flip_svd(u, vt)

x_transformed = Nx.dot(x, Nx.transpose(vt))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as below, use Nx.dot/4 here

Copy link
Member

@krstopro krstopro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good so far. I added few comments, will probably have another look soon.

@@ -471,7 +471,8 @@ defmodule Scholar.Decomposition.PCA do
end
end

defnp flip_svd(u, v) do
@doc false
defn flip_svd(u, v) do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would pull this out in a separate module in the same folder, e.g. Scholar.Decomposition.Utils or Scholar.Decomposition.Shared. See how it's done in Scholar.Neighbors.Utils.

]

tsvd_schema = [
n_components: [
Copy link
Member

@krstopro krstopro Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rename this to num_components to be consistent with the rest of Scholar.

type: :pos_integer,
doc: "Desired dimensionality of output data."
],
n_iter: [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, num_iters.

type: :pos_integer,
doc: "Number of iterations for randomized SVD solver."
],
n_oversamples: [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And same here, num_oversamples.

Copy link
Member

@krstopro krstopro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more suggestion.

You should also check that num_components is less than or equal to num_samples.

type: :pos_integer,
doc: "Number of oversamples for randomized SVD solver."
],
seed: [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should pass key here, which is of type {:custom, Scholar.Options, :key}.
See how it's done in e.g. Scholar.Clustering.KMeans.

Comment on lines 186 to 187
q_t = Nx.transpose(q)
b = Nx.dot(q_t, m)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since q_t is being used only for the Nx.dot call below, you can skip calculating it by using Nx.dot/4: b = Nx.dot(q, [-2], m, [-2])

@josevalim josevalim merged commit 2a601cc into elixir-nx:main Oct 4, 2024
2 checks passed
@josevalim
Copy link
Contributor

💚 💙 💜 💛 ❤️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants