diff --git a/river/sketch/hierarchical_heavy_hitters.py b/river/sketch/hierarchical_heavy_hitters.py index 4c936f78c2..fa32f7daff 100644 --- a/river/sketch/hierarchical_heavy_hitters.py +++ b/river/sketch/hierarchical_heavy_hitters.py @@ -189,7 +189,7 @@ def compress(self): def _compress_node(self, node: HierarchicalHeavyHitters.Node): """Recursively compress nodes in the hierarchical tree.""" - if not node.children: + if node is not None and not node.children: return for child_key, child_node in list(node.children.items()): @@ -247,26 +247,31 @@ def __getitem__(self, key: typing.Hashable) -> int: """Get the count of a specific hierarchical key.""" current = self.root - for i in range(len(key)): - + if isinstance(key, str): + for i in range(len(key)): sub_key = key[:i + 1] if sub_key not in current.children: - return 0 - + current = current.children[sub_key] if sub_key == key: - - return current.ge - + return current.ge + else: + + return 0 + return 0 def totals(self) -> int: """Return the total number of elements in the hierarchical tree.""" - return self._count_entries(self.root) -1 + if self.root is not None: + total = self._count_entries(self.root) - 1 + else: + total = 0 + return total def _count_entries(self, node: HierarchicalHeavyHitters.Node) -> int: """Recursively count the total number of nodes in the hierarchical tree.""" diff --git a/river/sketch/hyper_log_log.py b/river/sketch/hyper_log_log.py index 0ec28365ae..18281aad1b 100644 --- a/river/sketch/hyper_log_log.py +++ b/river/sketch/hyper_log_log.py @@ -114,7 +114,7 @@ def update(self, x: typing.Hashable): j = hash_val & (self.m - 1) w = hash_val >> self.b self.registers[j] = max(self.registers[j], self.left_most_one(w)) - return + return None def count(self) -> int: """ diff --git a/river/sketch/space_saving.py b/river/sketch/space_saving.py index 88de70cef0..5afef82c1d 100644 --- a/river/sketch/space_saving.py +++ b/river/sketch/space_saving.py @@ -72,13 +72,12 @@ def __init__(self, k: int): def update(self, x: typing.Hashable, w: int = 1): """Update the counts with the given element.""" + if x in self.counts: self.counts[x] += w - elif len(self.counts) >= self.k: - min_count_key = min(self.counts, key=self.counts.get) + min_count_key = min(self.counts, key=lambda k: self.counts[k]) # Use lambda to specify key function self.counts[x] = self.counts.pop(min_count_key, 0) + 1 - else: self.counts[x] = w