Skip to content

Commit

Permalink
Fix issue #1507 (SWINN streamline) (#1508)
Browse files Browse the repository at this point in the history
* streamline code

* now I remember why those methods exist

* change test

* release notes
  • Loading branch information
smastelini authored Mar 1, 2024
1 parent 1e6ded0 commit f603161
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 127 deletions.
10 changes: 8 additions & 2 deletions docs/unreleased.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Unreleased

## drift

- Added `FHDDM` drift detector.
- Added a `iter_polars` function to iterate over the rows of a polars DataFrame.
- Added `FHDDM` drift detector.
- Added a `iter_polars` function to iterate over the rows of a polars DataFrame.

## neighbors

- Simplified `neighbors.SWINN` to avoid recursion limit and pickling issues.
102 changes: 48 additions & 54 deletions river/neighbors/ann/nn_vertex.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
from __future__ import annotations

import heapq
import itertools
import math
import random

from river import base


class Vertex(base.Base):
_isolated: set[Vertex] = set()
_isolated: set[int] = set()

def __init__(self, item, uuid: int) -> None:
self.item = item
self.uuid = uuid
self.edges: dict[Vertex, float] = {}
self.r_edges: dict[Vertex, float] = {}
self.flags: set[Vertex] = set()
self.worst_edge: Vertex | None = None

def __hash__(self) -> int:
return self.uuid
self.edges: dict[int, float] = {}
self.r_edges: dict[int, float] = {}
self.flags: set[int]= set()
self.worst_edge: int | None = None

def __eq__(self, other) -> bool:
if not isinstance(other, Vertex):
Expand All @@ -34,69 +30,67 @@ def __lt__(self, other) -> bool:

return self.uuid < other.uuid

def farewell(self):
def farewell(self, vertex_pool: list[Vertex]):
for rn in list(self.r_edges):
rn.rem_edge(self)
vertex_pool[rn].rem_edge(self)

for n in list(self.edges):
self.rem_edge(n)

self.flags = None
self.worst_edge = None
self.rem_edge(vertex_pool[n])

Vertex._isolated.discard(self)
Vertex._isolated.discard(self.uuid)

def fill(self, neighbors: list[Vertex], dists: list[float]):
for n, dist in zip(neighbors, dists):
self.edges[n] = dist
self.flags.add(n)
n.r_edges[self] = dist
self.edges[n.uuid] = dist
self.flags.add(n.uuid)
n.r_edges[self.uuid] = dist

# Neighbors are ordered by distance
self.worst_edge = n
# Neighbors are ordered by distance, so the last neighbor
# is the farthest one
self.worst_edge = n.uuid

def add_edge(self, vertex: Vertex, dist):
self.edges[vertex] = dist
self.flags.add(vertex)
vertex.r_edges[self] = dist
self.edges[vertex.uuid] = dist
self.flags.add(vertex.uuid)
vertex.r_edges[self.uuid] = dist

if self.worst_edge is None or self.edges[self.worst_edge] < dist:
self.worst_edge = vertex
self.worst_edge = vertex.uuid

def rem_edge(self, vertex: Vertex):
self.edges.pop(vertex)
vertex.r_edges.pop(self)
self.flags.discard(vertex)
self.edges.pop(vertex.uuid)
vertex.r_edges.pop(self.uuid)
self.flags.discard(vertex.uuid)

if self.has_neighbors():
if vertex == self.worst_edge:
self.worst_edge = max(self.edges, key=self.edges.get) # type: ignore
if vertex.uuid == self.worst_edge:
self.worst_edge = max(self.edges, key=self.edges.__getitem__)
else:
self.worst_edge = None

if not self.has_rneighbors():
Vertex._isolated.add(self)
Vertex._isolated.add(self.uuid)

def push_edge(self, node: Vertex, dist: float, max_edges: int) -> int:
if self.is_neighbor(node) or node == self:
def push_edge(self, node: Vertex, dist: float, max_edges: int, vertex_pool: list[Vertex]) -> int:
if self.is_neighbor(node) or node.uuid == self.uuid:
return 0

if len(self.edges) >= max_edges:
if self.worst_edge is None or self.edges.get(self.worst_edge, math.inf) <= dist:
return 0
self.rem_edge(self.worst_edge)
self.rem_edge(vertex_pool[self.worst_edge])

self.add_edge(node, dist)

return 1

def is_neighbor(self, vertex):
return vertex in self.edges or vertex in self.r_edges
def is_neighbor(self, vertex: Vertex):
return vertex.uuid in self.edges or vertex.uuid in self.r_edges

def get_edge(self, vertex: Vertex):
if vertex in self.edges:
return self, vertex, self.edges[vertex]
return vertex, self, self.r_edges[vertex]
if vertex.uuid in self.edges:
return self, vertex, self.edges[vertex.uuid]
return vertex, self, self.r_edges[vertex.uuid]

def has_neighbors(self) -> bool:
return len(self.edges) > 0
Expand All @@ -112,50 +106,50 @@ def sample_flags(self):
def sample_flags(self, sampled):
self.flags -= set(sampled)

def neighbors(self) -> tuple[list[Vertex], list[float]]:
def neighbors(self) -> tuple[list[int], list[float]]:
res = tuple(map(list, zip(*((node, dist) for node, dist in self.edges.items()))))
return res if len(res) > 0 else ([], []) # type: ignore

def r_neighbors(self) -> tuple[list[Vertex], list[float]]:
def r_neighbors(self) -> tuple[list[int], list[float]]:
res = tuple(map(list, zip(*((vertex, dist) for vertex, dist in self.r_edges.items()))))
return res if len(res) > 0 else ([], []) # type: ignore

def all_neighbors(self):
def all_neighbors(self) -> set[int]:
return set.union(set(self.edges.keys()), set(self.r_edges.keys()))

def is_isolated(self):
return len(self.edges) == 0 and len(self.r_edges) == 0

def prune(self, prune_prob: float, prune_trigger: int, rng: random.Random):
def prune(self, prune_prob: float, prune_trigger: int, vertex_pool: list[Vertex], rng: random.Random):
if prune_prob == 0:
return

total_degree = len(self.edges) + len(self.r_edges)
if total_degree <= prune_trigger:
return

# To avoid tie in distances
counter = itertools.count()
edge_pool: list[tuple[float, int, Vertex, bool]] = []
edge_pool: list[tuple[float, int, bool]] = []
for n, dist in self.edges.items():
heapq.heappush(edge_pool, (dist, next(counter), n, True))
heapq.heappush(edge_pool, (dist, n, True))

for rn, dist in self.r_edges.items():
heapq.heappush(edge_pool, (dist, next(counter), rn, False))
heapq.heappush(edge_pool, (dist, rn, False))

# Start with the best undirected edge
selected: list[Vertex] = [heapq.heappop(edge_pool)[2]]
selected: list[int] = [heapq.heappop(edge_pool)[1]]
while len(edge_pool) > 0:
c_dist, _, c, c_isdir = heapq.heappop(edge_pool)
c_dist, c, c_isdir = heapq.heappop(edge_pool)
discarded = False
for s in selected:
if s.is_neighbor(c) and rng.random() < prune_prob:
orig, dest, dist = s.get_edge(c)
s_v = vertex_pool[s]
c_v = vertex_pool[c]
if s_v.is_neighbor(c_v) and rng.random() < prune_prob:
orig, dest, dist = s_v.get_edge(c_v)
if dist < c_dist:
if c_isdir:
self.rem_edge(c)
self.rem_edge(c_v)
else:
c.rem_edge(self)
c_v.rem_edge(self)
discarded = True
break
else:
Expand Down
Loading

0 comments on commit f603161

Please sign in to comment.