Added TruncatedSVD module #302
Co-authored-by: José Valim <[email protected]>
Looks really good, just small comments

  defnp fit_transform_n(x, opts) do
    module = fit_n(x, opts)
    Nx.dot(x, Nx.transpose(module.components))
Use Nx.dot/4 here, which automatically transposes the relevant tensors depending on the axes you provide.
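To illustrate the suggestion, here is a minimal sketch of how Nx.dot/4 can stand in for an explicit Nx.transpose/1. The tensors and the Nx version constraint are made up for the example and are not taken from the PR; only Nx.dot/2, Nx.dot/4, and Nx.transpose/1 are assumed to behave as documented.

```elixir
# Assumption: Nx can be fetched from Hex; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

# Illustrative data: 2 samples x 3 features, and 2 components x 3 features.
x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
components = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# x * transpose(components), written with an explicit transpose.
explicit = Nx.dot(x, Nx.transpose(components))

# Same product: contract axis 1 of x with axis 1 of components.
contracted = Nx.dot(x, [1], components, [1])

# Contracting the first axes instead gives transpose(q) * m.
q = Nx.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
qt_m_explicit = Nx.dot(Nx.transpose(q), m)
qt_m_contracted = Nx.dot(q, [0], m, [0])

IO.inspect(Nx.to_flat_list(explicit) == Nx.to_flat_list(contracted))
IO.inspect(Nx.to_flat_list(qt_m_explicit) == Nx.to_flat_list(qt_m_contracted))
```

Both comparisons should agree, since naming the contraction axes only changes how the product is expressed, not its result.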
    {q, _a, _a_t, _i, _n_iter} =
      while {q, a, a_t, i = Nx.tensor(1), n_iter}, Nx.less(i, n_iter) do
        {q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
        {q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
        {q, a, a_t, i + 1, n_iter}
      end
Suggested change:
    {q, _} =
      while {q, {a, a_t, i = Nx.tensor(1), n_iter}}, Nx.less(i, n_iter) do
        {q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
        {q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
        {q, {a, a_t, i + 1, n_iter}}
      end
    {u, sigma, vt} = randomized_svd(x, opts)
    {_u, vt} = Scholar.Decomposition.PCA.flip_svd(u, vt)

    x_transformed = Nx.dot(x, Nx.transpose(vt))
Same as below, use Nx.dot/4 here.
Looks good so far. I added a few comments and will probably have another look soon.
lib/scholar/decomposition/pca.ex
Outdated
@@ -471,7 +471,8 @@ defmodule Scholar.Decomposition.PCA do
     end
   end

-  defnp flip_svd(u, v) do
+  @doc false
+  defn flip_svd(u, v) do
I would pull this out in a separate module in the same folder, e.g. Scholar.Decomposition.Utils or Scholar.Decomposition.Shared. See how it's done in Scholar.Neighbors.Utils.
  ]

  tsvd_schema = [
    n_components: [
I would rename this to num_components to be consistent with the rest of Scholar.
      type: :pos_integer,
      doc: "Desired dimensionality of output data."
    ],
    n_iter: [
Same here, num_iters.
      type: :pos_integer,
      doc: "Number of iterations for randomized SVD solver."
    ],
    n_oversamples: [
And same here, num_oversamples.
One more suggestion. You should also check that num_components is less than or equal to num_samples.
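A rough sketch of what such a check could look like is below. The module name ComponentsCheck and the function validate!/2 are hypothetical and not part of Scholar; in the PR itself the check would more likely sit next to the option validation and read the sample count from the input tensor's shape.

```elixir
defmodule ComponentsCheck do
  # Hypothetical helper, not part of Scholar: raises when more components
  # are requested than there are samples in the input.
  def validate!(num_components, {num_samples, _num_features}) do
    if num_components > num_samples do
      raise ArgumentError,
            "expected :num_components to be less than or equal to " <>
              "num_samples = #{num_samples}, got: #{num_components}"
    end

    :ok
  end
end

# The second argument is the shape of the input, e.g. Nx.shape(x).
:ok = ComponentsCheck.validate!(2, {10, 5})
```

Failing early with an ArgumentError keeps a shape mismatch from surfacing later as a harder-to-read error inside the SVD computation.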
      type: :pos_integer,
      doc: "Number of oversamples for randomized SVD solver."
    ],
    seed: [
I think you should pass key here, which is of type {:custom, Scholar.Options, :key}. See how it's done in e.g. Scholar.Clustering.KMeans.
    q_t = Nx.transpose(q)
    b = Nx.dot(q_t, m)
Since q_t is being used only for the Nx.dot call below, you can skip calculating it by using Nx.dot/4: b = Nx.dot(q, [-2], m, [-2])
💚 💙 💜 💛 ❤️
I've added the TruncatedSVD transformer. Currently, the only implemented option for the algorithm parameter is randomized. The arpack algorithm is not implemented yet and can be added in the future. I believe this module will be a valuable addition, as the PaCMAP algorithm, which is referenced in issue #238, utilizes this method. For reference, here's the Scikit-learn implementation: Scikit-learn Truncated SVD implementation.