Skip to content

Commit

Permalink
add docs and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
santiago-imelio committed Dec 13, 2023
1 parent c19c200 commit eb644e2
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 5 deletions.
107 changes: 103 additions & 4 deletions lib/scholar/preprocessing/standard_scaler.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
defmodule Scholar.Preprocessing.StandardScaler do
@moduledoc """
Standardizes the tensor by removing the mean and scaling to unit variance.
#{~S'''
Formula for input tensor $x$:
$$
z = \frac{x - \mu}{\sigma}
$$
Where $\mu$ is the mean of the samples, and $\sigma$ is the standard deviation.
Standardization can be helpful in cases where the data follows
a Gaussian distribution (or Normal distribution) without outliers.
'''}
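Centering and scaling use the mean and standard deviation computed on the samples in the
training set. With the default options the statistics are computed over the whole tensor;
pass the `:axes` option to compute them along specific axes (for example, independently
for each feature). Mean and standard deviation are then stored to be used on new samples.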
"""

import Nx.Defn

defstruct [:deviation, :mean]
@derive {Nx.Container, containers: [:standard_deviation, :mean]}
defstruct [:standard_deviation, :mean]

opts_schema = [
axes: [
Expand All @@ -15,25 +34,105 @@ defmodule Scholar.Preprocessing.StandardScaler do

@opts_schema NimbleOptions.new!(opts_schema)

@doc """
Computes the standard deviation and mean of the samples to be used for later scaling.
## Options
#{NimbleOptions.docs(@opts_schema)}
## Return values
Returns a struct with the following fields:
* `standard_deviation`: the calculated standard deviation of samples.
* `mean`: the calculated mean of samples.
## Examples
iex> Scholar.Preprocessing.StandardScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]))
%Scholar.Preprocessing.StandardScaler{
standard_deviation: #Nx.Tensor<
f32[1][1]
[
[1.0657403469085693]
]
>,
mean: #Nx.Tensor<
f32[1][1]
[
[0.4444444477558136]
]
>
}
"""
deftransform fit(tensor, opts \\ []) do
NimbleOptions.validate!(opts, @opts_schema)
{std, mean} = fit_n(tensor, opts)

%__MODULE__{deviation: std, mean: mean}
%__MODULE__{standard_deviation: std, mean: mean}
end

defnp fit_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(Nx.equal(std, 0), 0.0, mean_reduced)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)

{std, mean_reduced}
end

deftransform transform(%__MODULE__{deviation: std, mean: mean}, tensor) do
@doc """
Performs the standardization of the tensor using a fitted scaler.
## Examples
iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> scaler = Scholar.Preprocessing.StandardScaler.fit(t)
%Scholar.Preprocessing.StandardScaler{
standard_deviation: #Nx.Tensor<
f32[1][1]
[
[1.0657403469085693]
]
>,
mean: #Nx.Tensor<
f32[1][1]
[
[0.4444444477558136]
]
>
}
iex> Scholar.Preprocessing.StandardScaler.transform(scaler, t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn transform(%__MODULE__{standard_deviation: std, mean: mean}, tensor) do
scale(tensor, std, mean)
end

@doc """
Standardizes the tensor by removing the mean and scaling to unit variance.
## Examples
iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
iex> Scholar.Preprocessing.StandardScaler.fit_transform(t)
#Nx.Tensor<
f32[3][3]
[
[0.5212860703468323, -1.3553436994552612, 1.4596009254455566],
[1.4596009254455566, -0.4170288145542145, -0.4170288145542145],
[-0.4170288145542145, 0.5212860703468323, -1.3553436994552612]
]
>
"""
defn fit_transform(tensor, opts \\ []) do
tensor
|> fit(opts)
Expand Down
4 changes: 3 additions & 1 deletion test/scholar/preprocessing/standard_scaler_test.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
defmodule StandardScalerTest do
defmodule Scholar.Preprocessing.StandardScalerTest do
use Scholar.Case, async: true
alias Scholar.Preprocessing.StandardScaler

doctest StandardScaler

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.15.6, 26.1, true)

doctest Scholar.Preprocessing.StandardScaler.transform/2 (2) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.15.6, 26.1, true)

doctest Scholar.Preprocessing.StandardScaler.fit/2 (1) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.14.5, 25.3)

doctest Scholar.Preprocessing.StandardScaler.fit/2 (1) (Scholar.Preprocessing.StandardScalerTest)

Check failure on line 5 in test/scholar/preprocessing/standard_scaler_test.exs

View workflow job for this annotation

GitHub Actions / main (1.14.5, 25.3)

doctest Scholar.Preprocessing.StandardScaler.transform/2 (2) (Scholar.Preprocessing.StandardScalerTest)

describe "fit_transform/2" do
test "applies standard scaling to data" do
data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
Expand Down

0 comments on commit eb644e2

Please sign in to comment.