From f603161966370cad27d7ee76780b6026b2a07563 Mon Sep 17 00:00:00 2001 From: Saulo Martiello Mastelini Date: Fri, 1 Mar 2024 18:17:57 -0300 Subject: [PATCH] Fix issue #1507 (SWINN streamline) (#1508) * streamline code * now I remember why those methods exist * change test * release notes --- docs/unreleased.md | 10 ++- river/neighbors/ann/nn_vertex.py | 102 ++++++++++++------------- river/neighbors/ann/swinn.py | 121 +++++++++++++++--------------- river/neighbors/ann/utils.py | 10 --- river/neighbors/knn_classifier.py | 2 +- 5 files changed, 118 insertions(+), 127 deletions(-) delete mode 100644 river/neighbors/ann/utils.py diff --git a/docs/unreleased.md b/docs/unreleased.md index b2c83d0050..743c1a54f7 100644 --- a/docs/unreleased.md +++ b/docs/unreleased.md @@ -1,4 +1,10 @@ +# Unreleased + ## drift -- Added `FHDDM` drift detector. -- Added a `iter_polars` function to iterate over the rows of a polars DataFrame. \ No newline at end of file +- Added `FHDDM` drift detector. +- Added a `iter_polars` function to iterate over the rows of a polars DataFrame. + +## neighbors + +- Simplified `neighbors.SWINN` to avoid recursion limit and pickling issues. diff --git a/river/neighbors/ann/nn_vertex.py b/river/neighbors/ann/nn_vertex.py index e1f85905c1..c17ea49c52 100644 --- a/river/neighbors/ann/nn_vertex.py +++ b/river/neighbors/ann/nn_vertex.py @@ -1,7 +1,6 @@ from __future__ import annotations import heapq -import itertools import math import random @@ -9,18 +8,15 @@ class Vertex(base.Base): - _isolated: set[Vertex] = set() + _isolated: set[int] = set() def __init__(self, item, uuid: int) -> None: self.item = item self.uuid = uuid - self.edges: dict[Vertex, float] = {} - self.r_edges: dict[Vertex, float] = {} - self.flags: set[Vertex] = set() - self.worst_edge: Vertex | None = None - - def __hash__(self) -> int: - return self.uuid + self.edges: dict[int, float] = {} + self.r_edges: dict[int, float] = {} + self.flags: set[int]= set() + self.worst_edge: int | None = None def __eq__(self, other) -> bool: if not isinstance(other, Vertex): @@ -34,69 +30,67 @@ def __lt__(self, other) -> bool: return self.uuid < other.uuid - def farewell(self): + def farewell(self, vertex_pool: list[Vertex]): for rn in list(self.r_edges): - rn.rem_edge(self) + vertex_pool[rn].rem_edge(self) for n in list(self.edges): - self.rem_edge(n) - - self.flags = None - self.worst_edge = None + self.rem_edge(vertex_pool[n]) - Vertex._isolated.discard(self) + Vertex._isolated.discard(self.uuid) def fill(self, neighbors: list[Vertex], dists: list[float]): for n, dist in zip(neighbors, dists): - self.edges[n] = dist - self.flags.add(n) - n.r_edges[self] = dist + self.edges[n.uuid] = dist + self.flags.add(n.uuid) + n.r_edges[self.uuid] = dist - # Neighbors are ordered by distance - self.worst_edge = n + # Neighbors are ordered by distance, so the last neighbor + # is the farthest one + self.worst_edge = n.uuid def add_edge(self, vertex: Vertex, dist): - self.edges[vertex] = dist - self.flags.add(vertex) - vertex.r_edges[self] = dist + self.edges[vertex.uuid] = dist + self.flags.add(vertex.uuid) + vertex.r_edges[self.uuid] = dist if self.worst_edge is None or self.edges[self.worst_edge] < dist: - self.worst_edge = vertex + self.worst_edge = vertex.uuid def rem_edge(self, vertex: Vertex): - self.edges.pop(vertex) - vertex.r_edges.pop(self) - self.flags.discard(vertex) + self.edges.pop(vertex.uuid) + vertex.r_edges.pop(self.uuid) + self.flags.discard(vertex.uuid) if self.has_neighbors(): - if vertex == self.worst_edge: - 
self.worst_edge = max(self.edges, key=self.edges.get) # type: ignore + if vertex.uuid == self.worst_edge: + self.worst_edge = max(self.edges, key=self.edges.__getitem__) else: self.worst_edge = None if not self.has_rneighbors(): - Vertex._isolated.add(self) + Vertex._isolated.add(self.uuid) - def push_edge(self, node: Vertex, dist: float, max_edges: int) -> int: - if self.is_neighbor(node) or node == self: + def push_edge(self, node: Vertex, dist: float, max_edges: int, vertex_pool: list[Vertex]) -> int: + if self.is_neighbor(node) or node.uuid == self.uuid: return 0 if len(self.edges) >= max_edges: if self.worst_edge is None or self.edges.get(self.worst_edge, math.inf) <= dist: return 0 - self.rem_edge(self.worst_edge) + self.rem_edge(vertex_pool[self.worst_edge]) self.add_edge(node, dist) return 1 - def is_neighbor(self, vertex): - return vertex in self.edges or vertex in self.r_edges + def is_neighbor(self, vertex: Vertex): + return vertex.uuid in self.edges or vertex.uuid in self.r_edges def get_edge(self, vertex: Vertex): - if vertex in self.edges: - return self, vertex, self.edges[vertex] - return vertex, self, self.r_edges[vertex] + if vertex.uuid in self.edges: + return self, vertex, self.edges[vertex.uuid] + return vertex, self, self.r_edges[vertex.uuid] def has_neighbors(self) -> bool: return len(self.edges) > 0 @@ -112,21 +106,21 @@ def sample_flags(self): def sample_flags(self, sampled): self.flags -= set(sampled) - def neighbors(self) -> tuple[list[Vertex], list[float]]: + def neighbors(self) -> tuple[list[int], list[float]]: res = tuple(map(list, zip(*((node, dist) for node, dist in self.edges.items())))) return res if len(res) > 0 else ([], []) # type: ignore - def r_neighbors(self) -> tuple[list[Vertex], list[float]]: + def r_neighbors(self) -> tuple[list[int], list[float]]: res = tuple(map(list, zip(*((vertex, dist) for vertex, dist in self.r_edges.items())))) return res if len(res) > 0 else ([], []) # type: ignore - def all_neighbors(self): + def all_neighbors(self) -> set[int]: return set.union(set(self.edges.keys()), set(self.r_edges.keys())) def is_isolated(self): return len(self.edges) == 0 and len(self.r_edges) == 0 - def prune(self, prune_prob: float, prune_trigger: int, rng: random.Random): + def prune(self, prune_prob: float, prune_trigger: int, vertex_pool: list[Vertex], rng: random.Random): if prune_prob == 0: return @@ -134,28 +128,28 @@ def prune(self, prune_prob: float, prune_trigger: int, rng: random.Random): if total_degree <= prune_trigger: return - # To avoid tie in distances - counter = itertools.count() - edge_pool: list[tuple[float, int, Vertex, bool]] = [] + edge_pool: list[tuple[float, int, bool]] = [] for n, dist in self.edges.items(): - heapq.heappush(edge_pool, (dist, next(counter), n, True)) + heapq.heappush(edge_pool, (dist, n, True)) for rn, dist in self.r_edges.items(): - heapq.heappush(edge_pool, (dist, next(counter), rn, False)) + heapq.heappush(edge_pool, (dist, rn, False)) # Start with the best undirected edge - selected: list[Vertex] = [heapq.heappop(edge_pool)[2]] + selected: list[int] = [heapq.heappop(edge_pool)[1]] while len(edge_pool) > 0: - c_dist, _, c, c_isdir = heapq.heappop(edge_pool) + c_dist, c, c_isdir = heapq.heappop(edge_pool) discarded = False for s in selected: - if s.is_neighbor(c) and rng.random() < prune_prob: - orig, dest, dist = s.get_edge(c) + s_v = vertex_pool[s] + c_v = vertex_pool[c] + if s_v.is_neighbor(c_v) and rng.random() < prune_prob: + orig, dest, dist = s_v.get_edge(c_v) if dist < c_dist: if c_isdir: - 
self.rem_edge(c) + self.rem_edge(c_v) else: - c.rem_edge(self) + c_v.rem_edge(self) discarded = True break else: diff --git a/river/neighbors/ann/swinn.py b/river/neighbors/ann/swinn.py index 35419c8812..5d71b0e2c7 100644 --- a/river/neighbors/ann/swinn.py +++ b/river/neighbors/ann/swinn.py @@ -117,7 +117,7 @@ def __init__( self.n_iters = n_iters self.seed = seed - self._data: collections.deque[Vertex] = collections.deque(maxlen=self.maxlen) + self._data: collections.deque[Vertex | None] = collections.deque(maxlen=self.maxlen) self._uuid = itertools.cycle(range(self.maxlen)) self._rng = random.Random(self.seed) self._index = False @@ -146,16 +146,16 @@ def _init_graph(self): def _fix_graph(self): """Connect every isolated node in the graph to their nearest neighbors.""" - for node in list(Vertex._isolated): - if not node.is_isolated(): + for nid in list(Vertex._isolated): + if not self[nid].is_isolated(): continue - neighbors, dists = self._search(node.item, self.graph_k) - node.fill(neighbors, dists) + neighbors, dists = self._search(self[nid].item, self.graph_k) + self[nid].fill(neighbors, dists) # Update class property Vertex._isolated.clear() - def _safe_node_removal(self): + def _safe_node_removal(self, nid: int): """Remove the oldest data point from the search graph. Make sure nodes are accessible from any given starting point after removing the oldest @@ -163,24 +163,24 @@ def _safe_node_removal(self): the only bridge between its neighbors. """ - node = self._data.popleft() + node = self[nid] # Get previous neighborhood info rns = node.r_neighbors()[0] ns = node.neighbors()[0] - node.farewell() + node.farewell(vertex_pool=self._data) # Nodes whose only direct neighbor was the removed node - rns = {rn for rn in rns if not rn.has_neighbors()} + rns = {rn for rn in rns if not self[rn].has_neighbors()} # Nodes whose only reverse neighbor was the removed node - ns = {n for n in ns if not n.has_rneighbors()} + ns = {n for n in ns if not self[n].has_rneighbors()} affected = list(rns | ns) isolated = rns.intersection(ns) # First we handle the unreachable nodes for al in isolated: - neighbors, dists = self._search(al.item, self.graph_k) - al.fill(neighbors, dists) + neighbors, dists = self._search(self[al].item, self.graph_k) + self[al].fill(neighbors, dists) rns -= isolated ns -= isolated @@ -192,26 +192,29 @@ def _safe_node_removal(self): # Check the group of nodes without reverse neighborhood for seeds # Thus we can join two separate groups if len(ns) > 0: - seed = self._rng.choice(ns) + seed = self[self._rng.choice(ns)] # Use the search index to create new connections - neighbors, dists = self._search(rn.item, self.graph_k, seed=seed, exclude=rn) - rn.fill(neighbors, dists) + neighbors, dists = self._search(self[rn].item, self.graph_k, seed=seed, exclude={rn}) + self[rn].fill(neighbors, dists) + + self._data[nid] = None + del node self._refine(affected) - def _refine(self, nodes: list[Vertex] = None): + def _refine(self, nodes: list[int] = None): """Update the nearest neighbor graph to improve the edge distances. Parameters ---------- nodes - The list of nodes for which the neighborhood refinement will be applied. + The list of node ids for which the neighborhood refinement will be applied. If `None`, all nodes will have their neighborhood enhanced. 
""" if nodes is None: - nodes = [n for n in self] + nodes = [n.uuid for n in self] min_changes = self.delta * self.graph_k * len(nodes) @@ -223,65 +226,62 @@ def _refine(self, nodes: list[Vertex] = None): old = collections.defaultdict(set) # Expand undirected neighborhood - for node in nodes: + for nid in nodes: + node = self[nid] neighbors = node.neighbors()[0] flags = node.sample_flags for neigh, flag in zip(neighbors, flags): # To avoid evaluating previous neighbors again - tried.add((node.uuid, neigh.uuid)) + tried.add((nid, neigh)) if flag: - new[node].add(neigh) - new[neigh].add(node) + new[nid].add(neigh) + new[neigh].add(nid) else: - old[node].add(neigh) - old[neigh].add(node) + old[nid].add(neigh) + old[neigh].add(nid) # Limits the maximum number of edges to explore and update sample flags - for node in nodes: - if len(new[node]) > self.max_candidates: - new[node] = self._rng.sample(tuple(new[node]), self.max_candidates) # type: ignore - else: - new[node] = new[node] + for nid in nodes: + if len(new[nid]) > self.max_candidates: + new[nid] = self._rng.sample(tuple(new[nid]), self.max_candidates) # type: ignore - if len(old[node]) > self.max_candidates: - old[node] = self._rng.sample(tuple(old[node]), self.max_candidates) # type: ignore - else: - old[node] = old[node] + if len(old[nid]) > self.max_candidates: + old[nid] = self._rng.sample(tuple(old[nid]), self.max_candidates) # type: ignore - node.sample_flags = new[node] + self[nid].sample_flags = new[nid] # Perform local joins an attempt to improve the neighborhood - for node in nodes: + for nid in nodes: # The origin of the join must have a boolean flag set to true - for n1 in new[node]: + for n1 in new[nid]: # Consider connections between vertices whose boolean flags are both true - for n2 in new[node]: - if n1.uuid == n2.uuid or n1.is_neighbor(n2): + for n2 in new[nid]: + if n1 == n2 or self[n1].is_neighbor(self[n2]): continue - if (n1.uuid, n2.uuid) in tried or (n2.uuid, n1.uuid) in tried: + if (n1, n2) in tried or (n2, n1) in tried: continue - dist = self.dist_func(n1.item, n2.item) - total_changes += n1.push_edge(n2, dist, self.graph_k) - total_changes += n2.push_edge(n1, dist, self.graph_k) + dist = self.dist_func(self[n1].item, self[n2].item) + total_changes += self[n1].push_edge(self[n2], dist, self.graph_k, self._data) + total_changes += self[n2].push_edge(self[n1], dist, self.graph_k, self._data) - tried.add((n1.uuid, n2.uuid)) + tried.add((n1, n2)) # Or one of the connections has a boolean flag set to false - for n2 in old[node]: - if n1.uuid == n2.uuid or n1.is_neighbor(n2): + for n2 in old[nid]: + if n1 == n2 or self[n1].is_neighbor(self[n2]): continue - if (n1.uuid, n2.uuid) in tried or (n2.uuid, n1.uuid) in tried: + if (n1, n2) in tried or (n2, n1) in tried: continue - dist = self.dist_func(n1.item, n2.item) - total_changes += n1.push_edge(n2, dist, self.graph_k) - total_changes += n2.push_edge(n1, dist, self.graph_k) + dist = self.dist_func(self[n1].item, self[n2].item) + total_changes += self[n1].push_edge(self[n2], dist, self.graph_k, self._data) + total_changes += self[n2].push_edge(self[n1], dist, self.graph_k, self._data) - tried.add((n1.uuid, n2.uuid)) + tried.add((n1, n2)) # Stopping criterion if total_changes <= min_changes: @@ -289,7 +289,7 @@ def _refine(self, nodes: list[Vertex] = None): # Reduce the number of edges, if needed for n in nodes: - n.prune(self.prune_prob, self.max_candidates, self._rng) + self[n].prune(self.prune_prob, self.max_candidates, self._data, self._rng) # Ensure that no node is 
isolated in the graph self._fix_graph() @@ -322,13 +322,16 @@ def append(self, item: typing.Any, **kwargs): # A slot will be replaced, so let's update the search graph first if len(self) == self.maxlen: - self._safe_node_removal() + self._safe_node_removal(node.uuid) # Assign the closest neighbors to the new item neighbors, dists = self._search(node.item, self.graph_k) # Add the new element to the buffer - self._data.append(node) + if len(self) == self.maxlen: + self._data[node.uuid] = node + else: + self._data.append(node) node.fill(neighbors, dists) def _linear_scan(self, item, k): @@ -340,7 +343,7 @@ def _linear_scan(self, item, k): return None - def _search(self, item, k, epsilon: float = 0.1, seed=None, exclude=None) -> tuple[list, list]: + def _search(self, item, k, epsilon: float = 0.1, seed: Vertex = None, exclude: set[int] | None = None) -> tuple[list, list]: # Limiter for the distance bound distance_scale = 1 + epsilon # Distance threshold for early stops @@ -348,15 +351,13 @@ def _search(self, item, k, epsilon: float = 0.1, seed=None, exclude=None) -> tup if exclude is None: exclude = set() - else: - exclude = {exclude.uuid} if seed is None: # Make sure the starting point for the search is valid while True: # Random seed point to start the search seed = self[self._rng.randint(0, len(self) - 1)] - if not seed.is_isolated() and seed.uuid not in exclude: + if seed is not None and not seed.is_isolated() and seed.uuid not in exclude: break dist = self.dist_func(item, seed.item) @@ -373,7 +374,7 @@ def _search(self, item, k, epsilon: float = 0.1, seed=None, exclude=None) -> tup c_dist, c_n = heapq.heappop(pool) while c_dist < distance_bound: - tns = [n for n in c_n.all_neighbors() if n.uuid not in visited] + tns = [self[n] for n in c_n.all_neighbors() if n not in visited] for n in tns: dist = self.dist_func(item, n.item) @@ -454,9 +455,9 @@ def connectivity(self) -> list[int]: """ forest = set() - trees = {n: {n} for n in self} + trees = {n.uuid: {n.uuid} for n in self} - edges = [((n1, n2), w) for n1 in self for n2, w in n1.edges.items()] + edges = [((n1.uuid, n2), w) for n1 in self for n2, w in n1.edges.items()] edges.sort(key=operator.itemgetter(1)) for (n1, n2), _ in edges: diff --git a/river/neighbors/ann/utils.py b/river/neighbors/ann/utils.py deleted file mode 100644 index 7d33f4f3b6..0000000000 --- a/river/neighbors/ann/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - - -def argsort(dists): - return sorted(range(len(dists)), key=dists.__getitem__) - - -def rem_duplicates(pool): - seen = set() - return [n for n in pool if not (n in seen or seen.add(n))] diff --git a/river/neighbors/knn_classifier.py b/river/neighbors/knn_classifier.py index 67264abf3d..b8da74b437 100644 --- a/river/neighbors/knn_classifier.py +++ b/river/neighbors/knn_classifier.py @@ -68,7 +68,7 @@ class KNNClassifier(base.Classifier): ... ) >>> evaluate.progressive_val_score(dataset, model, metrics.Accuracy()) - Accuracy: 89.67% + Accuracy: 89.59% """
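
Context for the release note above ("avoid recursion limit and pickling issues"): the patch replaces Vertex-to-Vertex references (edge dict keys, reverse edges, the worst-edge pointer, the `_isolated` set) with integer `uuid`s that index into the shared vertex pool (`self._data`), and methods such as `farewell`, `push_edge`, and `prune` now receive that pool explicitly instead of each vertex holding links to other vertex objects. The sketch below is a minimal illustration of why that matters; it is not river code, and the `ObjNode`/`IdNode` classes and the `chain` helper are hypothetical stand-ins for the pre- and post-patch `Vertex` layouts.

    import pickle
    import sys


    class ObjNode:
        """Stand-in for the pre-patch Vertex: edges keyed by neighbour objects."""

        def __init__(self, uuid: int) -> None:
            self.uuid = uuid
            self.edges: dict["ObjNode", float] = {}  # neighbour object -> distance

        def __hash__(self) -> int:  # mirrors the removed Vertex.__hash__
            return self.uuid


    class IdNode:
        """Stand-in for the post-patch Vertex: edges keyed by integer uuids."""

        def __init__(self, uuid: int) -> None:
            self.uuid = uuid
            self.edges: dict[int, float] = {}  # neighbour uuid -> distance


    def chain(cls, n: int) -> list:
        """Build a simple neighbour chain node_0 -> node_1 -> ... -> node_{n-1}."""
        pool = [cls(i) for i in range(n)]
        for a, b in zip(pool, pool[1:]):
            a.edges[b if cls is ObjNode else b.uuid] = 1.0
        return pool


    n = sys.getrecursionlimit() + 100

    try:
        # Pickling node_0 recurses through every neighbour object in the chain,
        # so the depth grows with the graph instead of staying constant.
        pickle.dumps(chain(ObjNode, n))
    except RecursionError:
        print("object-keyed edges: RecursionError while pickling")

    # Edge dicts now hold only ints and floats, so pickling stays shallow.
    pickle.dumps(chain(IdNode, n))
    print("id-keyed edges: pickled without recursion issues")

A related design choice visible in the diff: removed slots are kept as `None` in `self._data` (rather than popped from the deque) until they are reused, so the integer uuids stored in edge dicts remain valid indices into the vertex pool.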