Skip to content

Commit

Permalink
raise error when k > n
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Apr 15, 2024
1 parent 96e8ed2 commit ec08bc4
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,20 @@ defmodule Scholar.Neighbors.BruteKNN do
deftransform fit(data, opts) do
if Nx.rank(data) != 2 do
raise ArgumentError,
"expected input tensor to have shape {n_samples, n_features},
"expected input tensor to have shape {num_samples, num_features},
got tensor with shape: #{inspect(Nx.shape(data))}"
end

opts = NimbleOptions.validate!(opts, @opts_schema)
k = opts[:num_neighbors]

if k > Nx.axis_size(data, 0) do
raise ArgumentError,
"""
expected num_neighbors to be less than or equal to \
num_samples = #{Nx.axis_size(data, 0)}, got: #{k}
"""
end

metric =
case opts[:metric] do
Expand All @@ -90,7 +99,7 @@ defmodule Scholar.Neighbors.BruteKNN do
end

%__MODULE__{
num_neighbors: opts[:num_neighbors],
num_neighbors: k,
metric: metric,
data: data,
batch_size: opts[:batch_size]
Expand Down Expand Up @@ -129,7 +138,7 @@ defmodule Scholar.Neighbors.BruteKNN do
deftransform predict(%__MODULE__{} = model, query) do
if Nx.rank(query) != 2 do
raise ArgumentError,
"expected query tensor to have shape {?, n_features},
"expected query tensor to have shape {num_queries, num_features},
got tensor with shape: #{inspect(Nx.shape(query))}"
end

Expand Down

0 comments on commit ec08bc4

Please sign in to comment.