Skip to content

Commit

Permalink
Fix doctests by avoiding nesting #Nx.Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 10, 2024
1 parent 1765930 commit 4dec8ba
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 37 deletions.
43 changes: 19 additions & 24 deletions lib/scholar/neighbors/rnn_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,14 @@ defmodule Scholar.Neighbors.RNNClassifier do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
%Scholar.Neighbors.RNNClassifier{
data: #Nx.Tensor<
s64[4][2]
[
[1, 2],
[2, 4],
[1, 3],
[2, 5]
]
>,
labels: #Nx.Tensor<
s64[4]
[1, 0, 1, 1]
>,
data: Nx.tensor([
[1, 2],
[2, 4],
[1, 3],
[2, 5]
]),
labels: Nx.tensor([1, 0, 1, 1]),
weights: :uniform,
num_classes: 2,
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2,
Expand Down Expand Up @@ -177,19 +170,20 @@ defmodule Scholar.Neighbors.RNNClassifier do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNClassifier.predict_probability(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{#Nx.Tensor<
iex> {probs, mask} = Scholar.Neighbors.RNNClassifier.predict_probability(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
iex> probs
#Nx.Tensor<
f32[2][2]
[
[0.5, 0.5],
[0.0, 1.0]
]
>,
>
iex> mask
#Nx.Tensor<
u8[2]
[0, 0]
>}
>
"""
defn predict_probability(
%__MODULE__{
Expand Down Expand Up @@ -237,22 +231,23 @@ defmodule Scholar.Neighbors.RNNClassifier do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNClassifier.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{#Nx.Tensor<
iex> {distances, mask} = Scholar.Neighbors.RNNClassifier.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
iex> distances
#Nx.Tensor<
f32[2][4]
[
[2.469818353652954, 0.3162313997745514, 1.5811394453048706, 0.7071067690849304],
[0.10000114142894745, 2.1931710243225098, 1.0049877166748047, 3.132091760635376]
]
>,
>
iex> mask
#Nx.Tensor<
u8[2][4]
[
[0, 1, 0, 1],
[1, 0, 0, 0]
]
>}
>
"""
defn radius_neighbors(%__MODULE__{metric: metric, radius: radius, data: data}, x) do
distances = metric.(x, data)
Expand Down
22 changes: 9 additions & 13 deletions lib/scholar/neighbors/rnn_regressor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,16 @@ defmodule Scholar.Neighbors.RNNRegressor do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> Scholar.Neighbors.RNNRegressor.fit(x, y, num_classes: 2)
%Scholar.Neighbors.RNNRegressor{
data: #Nx.Tensor<
s64[4][2]
data: Nx.tensor(
[
[1, 2],
[2, 4],
[1, 3],
[2, 5]
]
>,
labels: #Nx.Tensor<
s64[4]
[1, 0, 1, 1]
>,
),
labels: Nx.tensor([1, 0, 1, 1]),
weights: :uniform,
num_classes: 2,
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2,
Expand Down Expand Up @@ -202,22 +197,23 @@ defmodule Scholar.Neighbors.RNNRegressor do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RNNRegressor.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNRegressor.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{#Nx.Tensor<
iex> {distances, mask} = Scholar.Neighbors.RNNRegressor.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
iex> distances
#Nx.Tensor<
f32[2][4]
[
[2.469818353652954, 0.3162313997745514, 1.5811394453048706, 0.7071067690849304],
[0.10000114142894745, 2.1931710243225098, 1.0049877166748047, 3.132091760635376]
]
>,
>
iex> mask
#Nx.Tensor<
u8[2][4]
[
[0, 1, 0, 1],
[1, 0, 0, 0]
]
>}
>
"""
defn radius_neighbors(%__MODULE__{metric: metric, radius: radius, data: data}, x) do
distances = metric.(x, data)
Expand Down

0 comments on commit 4dec8ba

Please sign in to comment.