Add TruncatedSVD module (#302)
norm4nn authored Oct 4, 2024
1 parent 975938a commit 2a601cc
Showing 3 changed files with 252 additions and 11 deletions.
14 changes: 3 additions & 11 deletions lib/scholar/decomposition/pca.ex
@@ -151,7 +151,7 @@ defmodule Scholar.Decomposition.PCA do
x_centered = x - mean
variance = Nx.sum(x_centered * x_centered / (num_samples - 1), axes: [0])
{u, s, vt} = Nx.LinAlg.svd(x_centered, full_matrices?: false)
- {_, vt} = flip_svd(u, vt)
+ {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)
components = vt[0..(num_components - 1)]
explained_variance = s * s / (num_samples - 1)

@@ -311,7 +311,7 @@ defmodule Scholar.Decomposition.PCA do
)

{u, s, vt} = Nx.LinAlg.svd(matrix, full_matrices?: false)
- {_, vt} = flip_svd(u, vt)
+ {_, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)
updated_components = vt[0..(num_components - 1)]
updated_singular_values = s[0..(num_components - 1)]

@@ -461,7 +461,7 @@ defmodule Scholar.Decomposition.PCA do
mean = Nx.mean(x, axes: [0])
x_centered = x - mean
{u, s, vt} = Nx.LinAlg.svd(x_centered, full_matrices?: false)
- {u, _} = flip_svd(u, vt)
+ {u, _} = Scholar.Decomposition.Utils.flip_svd(u, vt)
u = u[[.., 0..(num_components - 1)]]

if opts[:whiten?] do
@@ -470,12 +470,4 @@ defmodule Scholar.Decomposition.PCA do
u * s[0..(num_components - 1)]
end
end

- defnp flip_svd(u, v) do
-   max_abs_cols_idx = u |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true)
-   signs = u |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze()
-   u = u * signs
-   v = v * Nx.new_axis(signs, -1)
-   {u, v}
- end
end
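
The three pca.ex hunks above only replace the module-local flip_svd/2 with the shared Scholar.Decomposition.Utils.flip_svd/2 (defined in the new utils.ex below), so PCA results should be unchanged by this commit. A minimal sketch of exercising the touched code path, assuming the existing public API Scholar.Decomposition.PCA.fit/2 with a :num_components option (illustrative only, output omitted):

    # Illustrative sketch, not part of the diff. Fitting runs through the
    # SVD + sign-flip call sites edited above; the component signs stay
    # deterministic because flip_svd/2 is still applied, now from Utils.
    x = Nx.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [3.0, 3.0], [4.0, 4.5]])
    pca = Scholar.Decomposition.PCA.fit(x, num_components: 2)
    pca.components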
236 changes: 236 additions & 0 deletions lib/scholar/decomposition/truncated_svd.ex
@@ -0,0 +1,236 @@
defmodule Scholar.Decomposition.TruncatedSVD do
@moduledoc """
Dimensionality reduction using truncated SVD (also known as latent semantic analysis, LSA).
This transformer performs linear dimensionality reduction by means of
truncated singular value decomposition (SVD).
"""
import Nx.Defn

@derive {Nx.Container,
containers: [
:components,
:explained_variance,
:explained_variance_ratio,
:singular_values
]}
defstruct [
:components,
:explained_variance,
:explained_variance_ratio,
:singular_values
]

tsvd_schema = [
num_components: [
default: 2,
type: :pos_integer,
doc: "Desired dimensionality of output data."
],
num_iter: [
default: 5,
type: :pos_integer,
doc: "Number of iterations for randomized SVD solver."
],
num_oversamples: [
default: 10,
type: :pos_integer,
doc: "Number of oversamples for randomized SVD solver."
],
key: [
type: {:custom, Scholar.Options, :key, []},
doc: """
Key for random tensor generation.
If the key is not provided, it is set to `Nx.Random.key(System.system_time())`.
"""
]
]

@tsvd_schema NimbleOptions.new!(tsvd_schema)

@doc """
Fit model on training data X.
## Options
#{NimbleOptions.docs(@tsvd_schema)}
## Return Values
The function returns a struct with the following parameters:
* `:components` - tensor of shape `{num_components, num_features}`
The right singular vectors of the input data.
* `:explained_variance` - tensor of shape `{num_components}`
The variance of the training samples transformed by a projection to
each component.
* `:explained_variance_ratio` - tensor of shape `{num_components}`
Percentage of variance explained by each of the selected components.
* `:singular_values` - tensor of shape `{num_components}`
The singular values corresponding to each of the selected components.
## Examples
iex> key = Nx.Random.key(0)
iex> x = Nx.tensor([[0, 0], [1, 0], [1, 1], [3, 3], [4, 4.5]])
iex> tsvd = Scholar.Decomposition.TruncatedSVD.fit(x, num_components: 2, key: key)
iex> tsvd.components
#Nx.Tensor<
f32[2][2]
[
[0.6871105432510376, 0.7265529036521912],
[0.7265529036521912, -0.6871105432510376]
]
>
iex> tsvd.singular_values
#Nx.Tensor<
f32[2]
[7.528080940246582, 0.7601959705352783]
>
"""

deftransform fit(x, opts \\ []) do
opts = NimbleOptions.validate!(opts, @tsvd_schema)
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)
fit_n(x, key, opts)
end

@doc """
Fit model to X and perform dimensionality reduction on X.
## Options
#{NimbleOptions.docs(@tsvd_schema)}
## Return Values
The function returns a tensor of shape `{num_samples, num_components}` - the reduced version of `x`.
## Examples
iex> key = Nx.Random.key(0)
iex> x = Nx.tensor([[0, 0], [1, 0], [1, 1], [3, 3], [4, 4.5]])
iex> Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)
#Nx.Tensor<
f32[5][2]
[
[0.0, 0.0],
[0.6871105432510376, 0.7265529036521912],
[1.413663387298584, 0.039442360401153564],
[4.240990161895752, 0.1183270812034607],
[6.017930030822754, -0.18578583002090454]
]
>
"""

deftransform fit_transform(x, opts \\ []) do
opts = NimbleOptions.validate!(opts, @tsvd_schema)
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)
fit_transform_n(x, key, opts)
end

defnp fit_n(x, key, opts) do
if Nx.rank(x) != 2 do
raise ArgumentError, "expected x to have rank equal to: 2, got: #{inspect(Nx.rank(x))}"
end

num_components = opts[:num_components]
{num_samples, num_features} = Nx.shape(x)

cond do
num_components > num_features ->
raise ArgumentError,
"""
num_components must be less than or equal to \
num_features = #{num_features}, got #{num_components}
"""

num_components > num_samples ->
raise ArgumentError,
"""
num_components must be less than or equal to \
num_samples = #{num_samples}, got #{num_components}
"""

true ->
nil
end

{u, sigma, vt} = randomized_svd(x, key, opts)
{_u, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)

x_transformed = Nx.dot(x, [1], vt, [0])
explained_variance = Nx.variance(x_transformed, axes: [0])
full_variance = Nx.variance(x, axes: [0]) |> Nx.sum()
explained_variance_ratio = explained_variance / full_variance

%__MODULE__{
components: vt,
explained_variance: explained_variance,
explained_variance_ratio: explained_variance_ratio,
singular_values: sigma
}
end

defnp fit_transform_n(x, key, opts) do
module = fit_n(x, key, opts)
Nx.dot(x, [1], module.components, [0])
end

defnp randomized_svd(m, key, opts) do
num_components = opts[:num_components]
num_oversamples = opts[:num_oversamples]
num_iter = opts[:num_iter]
n_random = num_components + num_oversamples
{num_samples, num_features} = Nx.shape(m)

transpose = num_samples < num_components

m =
if Nx.equal(transpose, 1) do
Nx.transpose(m)
else
m
end

q = randomized_range_finder(m, key, size: n_random, num_iter: num_iter)

b = Nx.dot(q, [-2], m, [-2])
{uhat, s, vt} = Nx.LinAlg.svd(b)
u = Nx.dot(q, uhat)
vt = Nx.slice(vt, [0, 0], [num_components, num_features])
s = Nx.slice(s, [0], [num_components])
u = Nx.slice(u, [0, 0], [num_samples, num_components])

if Nx.equal(transpose, 1) do
{Nx.transpose(vt), s, Nx.transpose(u)}
else
{u, s, vt}
end
end

defn randomized_range_finder(a, key, opts) do
size = opts[:size]
num_iter = opts[:num_iter]

{_, a_cols} = Nx.shape(a)
{q, _} = Nx.Random.normal(key, shape: {a_cols, size})
a_t = Nx.transpose(a)

{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
{q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))

{q, _} =
while {q, {a, a_t, i = Nx.tensor(1), num_iter}}, Nx.less(i, num_iter) do
{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
{q, _} = Nx.LinAlg.qr(Nx.dot(a_t, q))
{q, {a, a_t, i + 1, num_iter}}
end

{q, _} = Nx.LinAlg.qr(Nx.dot(a, q))
q
end
end
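
A small end-to-end sketch of the two entry points defined above (the key and input values are arbitrary, outputs omitted):

    # Illustrative sketch, not part of the diff.
    key = Nx.Random.key(42)
    x = Nx.tensor([[0, 0], [1, 0], [1, 1], [3, 3], [4, 4.5]])
    # fit/2 returns the %Scholar.Decomposition.TruncatedSVD{} struct holding
    # the fitted factors.
    tsvd = Scholar.Decomposition.TruncatedSVD.fit(x, num_components: 2, key: key)
    tsvd.explained_variance_ratio
    tsvd.singular_values
    # fit_transform/2 with the same options returns the projected training
    # data of shape {num_samples, num_components} directly.
    Scholar.Decomposition.TruncatedSVD.fit_transform(x, num_components: 2, key: key)

Passing :key explicitly, as above, makes the randomized solver reproducible; when it is omitted the key defaults to Nx.Random.key(System.system_time()), so repeated fits may differ slightly in the trailing components.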
13 changes: 13 additions & 0 deletions lib/scholar/decomposition/utils.ex
@@ -0,0 +1,13 @@
defmodule Scholar.Decomposition.Utils do
@moduledoc false
import Nx.Defn
require Nx

defn flip_svd(u, v) do
max_abs_cols_idx = u |> Nx.abs() |> Nx.argmax(axis: 0, keep_axis: true)
signs = u |> Nx.take_along_axis(max_abs_cols_idx, axis: 0) |> Nx.sign() |> Nx.squeeze()
u = u * signs
v = v * Nx.new_axis(signs, -1)
{u, v}
end
end
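
For reference, a minimal sketch (not part of the commit) of what the shared helper does: for each column of u it takes the sign of the largest-magnitude entry, multiplies the column by that sign, and applies the matching sign to the corresponding row of v, so the factors still multiply back to the original matrix while their signs become deterministic across runs and backends:

    # Illustrative sketch, not part of the diff.
    x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    {u, s, vt} = Nx.LinAlg.svd(x, full_matrices?: false)
    {u, vt} = Scholar.Decomposition.Utils.flip_svd(u, vt)
    # Reconstruction u * diag(s) * vt still matches x up to floating-point error.
    u |> Nx.multiply(s) |> Nx.dot(vt)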
