Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical clustering improvements #213

Open
3 tasks
josevalim opened this issue Nov 28, 2023 · 2 comments
Open
3 tasks

Hierarchical clustering improvements #213

josevalim opened this issue Nov 28, 2023 · 2 comments
Labels
help wanted Extra attention is needed

Comments

@josevalim
Copy link
Contributor

@josevalim
Copy link
Contributor Author

Here is a small patch for the first task. Unfortunately, it is not enough on its own, so something else is probably wrong:

diff --git a/lib/scholar/cluster/hierarchical.ex b/lib/scholar/cluster/hierarchical.ex
index 481b6cd..562367f 100644
--- a/lib/scholar/cluster/hierarchical.ex
+++ b/lib/scholar/cluster/hierarchical.ex
@@ -196,10 +196,11 @@ defmodule Scholar.Cluster.Hierarchical do
     clades = Nx.broadcast(-1, {n - 1, 2})
     sizes = Nx.broadcast(1, {2 * n - 1})
     pointers = Nx.broadcast(-1, {2 * n - 2})
+    n_sizes = Nx.broadcast(1, {n})
     diss = Nx.tensor(:infinity, type: Nx.type(pairwise)) |> Nx.broadcast({n - 1})
 
-    {{clades, diss, sizes}, _} =
-      while {{clades, diss, sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
+    {{clades, diss, sizes, n_sizes}, _} =
+      while {{clades, diss, sizes, n_sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
         # Indexes of who I am nearest to
         nearest = Nx.argmin(pairwise, axis: 1)
 
@@ -213,10 +214,21 @@ defmodule Scholar.Cluster.Hierarchical do
         # They are bidirectional but let's keep only one side.
         links = Nx.select(clades_selector and nearest > nearest_of_nearest, nearest, n)
 
-        {clades, count, pointers, pairwise, diss, sizes} =
-          merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun)
-
-        {{clades, diss, sizes}, {count, pointers, pairwise}}
+        {clades, count, pointers, pairwise, diss, sizes, n_sizes} =
+          merge_clades(
+            clades,
+            count,
+            pointers,
+            pairwise,
+            diss,
+            sizes,
+            n_sizes,
+            links,
+            n,
+            update_fun
+          )
+
+        {{clades, diss, sizes, n_sizes}, {count, pointers, pairwise}}
       end
 
     sizes = sizes[n..(2 * n - 2)]
@@ -224,16 +236,27 @@ defmodule Scholar.Cluster.Hierarchical do
     {clades[perm], diss[perm], sizes[perm]}
   end
 
-  defnp merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun) do
-    {{clades, count, pointers, pairwise, diss, sizes}, _} =
-      while {{clades, count, pointers, pairwise, diss, sizes}, links},
+  defnp merge_clades(
+          clades,
+          count,
+          pointers,
+          pairwise,
+          diss,
+          sizes,
+          n_sizes,
+          links,
+          n,
+          update_fun
+        ) do
+    {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, _} =
+      while {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links},
             i <- 0..(Nx.size(links) - 1) do
         # i < j because of how links is formed.
         # i will become the new clade index and we "infinity-out" j.
         j = links[i]
 
         if j == n do
-          {{clades, count, pointers, pairwise, diss, sizes}, links}
+          {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links}
         else
           # Clades a and b (i and j of pairwise) are being merged into c.
           indices = [i, j] |> Nx.stack() |> Nx.new_axis(-1)
@@ -251,6 +274,9 @@ defmodule Scholar.Cluster.Hierarchical do
           sc = sa + sb
           sizes = Nx.indexed_put(sizes, Nx.stack([i, c]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))
 
+          n_sizes =
+            Nx.indexed_put(n_sizes, Nx.stack([i, j]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))
+
           # Update dissimilarities
           diss = Nx.indexed_put(diss, Nx.stack([count]), pairwise[i][j])
 
@@ -259,7 +285,7 @@ defmodule Scholar.Cluster.Hierarchical do
 
           # Update pairwise
           updates =
-            update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, sc)
+            update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, n_sizes)
             |> Nx.indexed_put(indices, Nx.broadcast(:infinity, {2}))
 
           pairwise =
@@ -269,11 +295,11 @@ defmodule Scholar.Cluster.Hierarchical do
             |> Nx.put_slice([j, 0], Nx.broadcast(:infinity, {1, n}))
             |> Nx.put_slice([0, j], Nx.broadcast(:infinity, {n, 1}))
 
-          {{clades, count + 1, pointers, pairwise, diss, sizes}, links}
+          {{clades, count + 1, pointers, pairwise, diss, sizes, n_sizes}, links}
         end
       end
 
-    {clades, count, pointers, pairwise, diss, sizes}
+    {clades, count, pointers, pairwise, diss, sizes, n_sizes}
   end
 
   defnp find_clade(pointers, i) do
diff --git a/test/scholar/cluster/hierarchical_test.exs b/test/scholar/cluster/hierarchical_test.exs
index 6c4e5d5..4511252 100644
--- a/test/scholar/cluster/hierarchical_test.exs
+++ b/test/scholar/cluster/hierarchical_test.exs
@@ -127,7 +127,6 @@ defmodule Scholar.Cluster.HierarchicalTest do
       assert model.dissimilarities == Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0])
     end
 
-    @tag :skip
     test "ward", %{data: data} do
       model = Hierarchical.fit(data, linkage: :ward)

@josevalim
Copy link
Contributor Author

I have commented out Ward for now; see commit 6845727.

@josevalim josevalim added the help wanted Extra attention is needed label Mar 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
help wanted Extra attention is needed
Projects
None yet
Development

No branches or pull requests

1 participant