From ec08bc4e9a195d1f8069df83db3d72036bba3622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Mon, 15 Apr 2024 13:28:35 +0200 Subject: [PATCH] raise error when k > n --- lib/scholar/neighbors/brute_knn.ex | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 4fae09b9..bfb584af 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -71,11 +71,20 @@ defmodule Scholar.Neighbors.BruteKNN do deftransform fit(data, opts) do if Nx.rank(data) != 2 do raise ArgumentError, - "expected input tensor to have shape {n_samples, n_features}, + "expected input tensor to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(data))}" end opts = NimbleOptions.validate!(opts, @opts_schema) + k = opts[:num_neighbors] + + if k > Nx.axis_size(data, 0) do + raise ArgumentError, + """ + expected num_neighbors to be less than or equal to \ + num_samples = #{Nx.axis_size(data, 0)}, got: #{k} + """ + end metric = case opts[:metric] do @@ -90,7 +99,7 @@ defmodule Scholar.Neighbors.BruteKNN do end %__MODULE__{ - num_neighbors: opts[:num_neighbors], + num_neighbors: k, metric: metric, data: data, batch_size: opts[:batch_size] @@ -129,7 +138,7 @@ defmodule Scholar.Neighbors.BruteKNN do deftransform predict(%__MODULE__{} = model, query) do if Nx.rank(query) != 2 do raise ArgumentError, - "expected query tensor to have shape {?, n_features}, + "expected query tensor to have shape {num_queries, num_features}, got tensor with shape: #{inspect(Nx.shape(query))}" end