Clear out MyPy's warnings
Type annotations are added to some functions; this clears MyPy's
warnings when type-checking the code by resolving some
inconsistencies.

A couple of incorrect or extraneous annotations were removed in the
process.
e10e3 committed Jul 23, 2024
1 parent 9d05b77 commit d7dace7
Showing 6 changed files with 21 additions and 20 deletions.
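A short note for context, with a hedged sketch: the recurring pattern in this commit is adding a return annotation (most often `-> None`) so that MyPy, depending on the flags in use (for example `--disallow-untyped-defs`), stops flagging the method as untyped. The class names below are made up for illustration; only the `_reset` method mirrors the actual diff:

```python
class BeforeSketch:
    # Without a return annotation, strict MyPy settings report the method
    # as missing a type annotation and treat its return value as Any.
    def _reset(self):
        self._drift_detected = False


class AfterSketch:
    # The explicit "-> None" clears that warning and documents that the
    # method only mutates state.
    def _reset(self) -> None:
        self._drift_detected = False
```

Re-running the type checker, for instance with `mypy river/` or whatever invocation the project's CI uses, should then report fewer warnings for the touched modules.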
6 changes: 3 additions & 3 deletions river/base/drift_detector.py
@@ -22,12 +22,12 @@ class _BaseDriftDetector(base.Base):
def __init__(self):
self._drift_detected = False

- def _reset(self):
+ def _reset(self) -> None:
"""Reset the detector's state."""
self._drift_detected = False

@property
- def drift_detected(self):
+ def drift_detected(self) -> bool:
"""Whether or not a drift is detected following the last update."""
return self._drift_detected

@@ -57,7 +57,7 @@ class DriftDetector(_BaseDriftDetector):
"""A drift detector."""

@abc.abstractmethod
- def update(self, x: int | float) -> DriftDetector:
+ def update(self, x: int | float) -> None:
"""Update the detector with a single data point.
Parameters
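A hedged aside on this file: replacing `-> DriftDetector` with `-> None` matches an `update` that mutates the detector in place instead of returning it, presumably one of the "incorrect" annotations the commit message mentions. Below is a minimal sketch of a concrete detector written against the new signature; the counting logic is invented purely for illustration:

```python
from __future__ import annotations

from river import base


class CountingDetector(base.DriftDetector):
    """Toy detector: signals drift after `patience` values above `threshold`."""

    def __init__(self, threshold: float = 1.0, patience: int = 3):
        super().__init__()
        self.threshold = threshold
        self.patience = patience
        self._exceed_count = 0

    def update(self, x: int | float) -> None:
        # Only internal state is mutated; nothing is returned, in line
        # with the updated annotation.
        self._drift_detected = False
        if x > self.threshold:
            self._exceed_count += 1
            if self._exceed_count >= self.patience:
                self._drift_detected = True
                self._exceed_count = 0
        else:
            self._exceed_count = 0


detector = CountingDetector()
for value in [0.2, 1.5, 1.7, 1.9, 0.1]:
    detector.update(value)
    if detector.drift_detected:
        print("drift detected at", value)
```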
2 changes: 1 addition & 1 deletion river/drift/kswin.py
@@ -109,7 +109,7 @@ def _reset(self):
super()._reset()
self.p_value = 0
self.n = 0
- self.window: collections.deque = collections.deque(maxlen=self.window_size)
+ self.window = collections.deque(maxlen=self.window_size)
self._rng = random.Random(self.seed)

def update(self, x):
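A hedged note on why an annotation is being removed here rather than added: `self.window` is presumably already annotated where it is first assigned (in `__init__`), so repeating the annotation in `_reset` is one of the "extraneous" ones from the commit message. A self-contained sketch of the pattern, with a made-up class name standing in for KSWIN:

```python
from __future__ import annotations

import collections
import random


class WindowedSketch:
    def __init__(self, window_size: int = 100, seed: int | None = None):
        self.window_size = window_size
        self.seed = seed
        # Annotate the attribute once, at its first assignment.
        self.window: collections.deque = collections.deque(maxlen=window_size)
        self._rng = random.Random(seed)

    def _reset(self) -> None:
        # Plain re-assignment; a second annotation here would be redundant.
        self.window = collections.deque(maxlen=self.window_size)
        self._rng = random.Random(self.seed)
```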
4 changes: 2 additions & 2 deletions river/forest/aggregated_mondrian_forest.py
@@ -169,7 +169,7 @@ def __init__(
# memory of the classes
self._classes: set[base.typing.ClfTarget] = set()

- def _initialize_trees(self):
+ def _initialize_trees(self) -> None:
self.data: list[MondrianTreeClassifier] = []
for _ in range(self.n_estimators):
tree = MondrianTreeClassifier(
@@ -287,7 +287,7 @@ def __init__(

self.iteration = 0

- def _initialize_trees(self):
+ def _initialize_trees(self) -> None:
"""Initialize the forest."""

self.data: list[MondrianTreeRegressor] = []
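A hedged sketch of the pattern in this file: the trees are created in a helper rather than in `__init__`, so the helper gets an explicit `-> None` and the `self.data` attribute is annotated at the point where it is created, which is enough for MyPy to check later uses. The names below are simplified stand-ins, not the forest's real API:

```python
from __future__ import annotations


class TreeStub:
    """Stand-in for a Mondrian tree, only to keep the sketch self-contained."""


class ForestSketch:
    def __init__(self, n_estimators: int = 10):
        self.n_estimators = n_estimators
        self._initialize_trees()

    def _initialize_trees(self) -> None:
        # Annotating the attribute where it is first assigned lets MyPy
        # type-check self.data even though it is not set in __init__.
        self.data: list[TreeStub] = []
        for _ in range(self.n_estimators):
            self.data.append(TreeStub())
```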
3 changes: 2 additions & 1 deletion river/tree/nodes/hatc_nodes.py
@@ -2,6 +2,7 @@

import math

+ from river import base
from river import stats as st
from river.utils.norm import normalize_values_in_dict
from river.utils.random import poisson
@@ -133,7 +134,7 @@ class AdaBranchClassifier(DTBranch):
Other parameters passed to the split node.
"""

- def __init__(self, stats, *children, drift_detector, **attributes):
+ def __init__(self, stats: dict, *children, drift_detector: base.DriftDetector, **attributes):
super().__init__(stats, *children, **attributes)
self.drift_detector = drift_detector
self._alternate_tree = None
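A hedged note on this file: annotating `drift_detector` with `base.DriftDetector` is what makes the new `from river import base` import at the top of the module necessary. A minimal stand-in showing how such an annotated constructor is used; the class below is a toy, but `ADWIN` is a real detector from `river.drift`:

```python
from __future__ import annotations

from river import base
from river.drift import ADWIN


class BranchSketch:
    """Toy stand-in echoing the annotated AdaBranchClassifier signature."""

    def __init__(self, stats: dict, *children, drift_detector: base.DriftDetector, **attributes):
        self.stats = stats
        self.children = list(children)
        self.drift_detector = drift_detector
        self.attributes = attributes


# Any concrete DriftDetector satisfies the annotation, e.g. ADWIN.
node = BranchSketch({}, drift_detector=ADWIN())
```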
4 changes: 2 additions & 2 deletions river/tree/nodes/sgt_nodes.py
@@ -31,7 +31,7 @@ class SGTLeaf(Leaf):
Parameters passed to the feature quantizers.
"""

- def __init__(self, prediction=0.0, depth=0, split_params=None):
+ def __init__(self, prediction: float = 0.0, depth: int = 0, split_params: dict | None = None):
super().__init__()
self._prediction = prediction
self.depth = depth
@@ -52,7 +52,7 @@ def reset(self):
self._update_stats = GradHessStats()

@staticmethod
- def is_categorical(idx, x_val, nominal_attributes):
+ def is_categorical(idx: str, x_val, nominal_attributes: list[str]) -> bool:
return not isinstance(x_val, numbers.Number) or idx in nominal_attributes

def update(self, x: dict, gh: GradHess, sgt, w: float = 1.0):
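A hedged mirror of the two signatures touched here, mostly to show the annotated defaults and the boolean predicate in isolation; only the body of `is_categorical` is copied from the diff, the rest is simplified:

```python
from __future__ import annotations

import numbers


class LeafSketch:
    def __init__(self, prediction: float = 0.0, depth: int = 0, split_params: dict | None = None):
        self._prediction = prediction
        self.depth = depth
        # Avoid a shared mutable default by falling back to an empty dict.
        self.split_params = split_params if split_params is not None else {}

    @staticmethod
    def is_categorical(idx: str, x_val, nominal_attributes: list[str]) -> bool:
        # Categorical if the value is not numeric, or the feature is
        # explicitly declared nominal.
        return not isinstance(x_val, numbers.Number) or idx in nominal_attributes


assert LeafSketch.is_categorical("colour", "red", [])
assert LeafSketch.is_categorical("grade", 3, ["grade"])
assert not LeafSketch.is_categorical("age", 42.0, [])
```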
22 changes: 11 additions & 11 deletions river/tree/stochastic_gradient_tree.py
@@ -7,7 +7,7 @@

from river import base, tree

- from .losses import BinaryCrossEntropyLoss, SquaredErrorLoss
+ from .losses import BinaryCrossEntropyLoss, Loss, SquaredErrorLoss
from .nodes.branch import DTBranch, NominalMultiwayBranch, NumericBinaryBranch
from .nodes.sgt_nodes import SGTLeaf
from .utils import BranchFactory, GradHessMerit
@@ -23,15 +23,15 @@ class StochasticGradientTree(base.Estimator, abc.ABC):

def __init__(
self,
- loss_func,
- delta,
- grace_period,
- init_pred,
- max_depth,
- lambda_value,
- gamma,
- nominal_attributes,
- feature_quantizer,
+ loss_func: Loss,
+ delta: float,
+ grace_period: int,
+ init_pred: float,
+ max_depth: int | None,
+ lambda_value: float,
+ gamma: float,
+ nominal_attributes: list[str] | None,
+ feature_quantizer: tree.splitter.Quantizer | None,
):
# What really defines how a SGT works is its loss function
self.loss_func = loss_func
@@ -56,7 +56,7 @@ def __init__(
self._root: SGTLeaf | DTBranch = SGTLeaf(prediction=self.init_pred)

# set used to check whether categorical feature has been already split
- self._split_features = set()
+ self._split_features: set[str] = set()
self._n_splits = 0
self._n_node_updates = 0
self._n_observations = 0
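Two hedged observations about this file: annotating `loss_func: Loss` is why the `Loss` base class is added to the import line, and annotating `self._split_features: set[str]` is needed because MyPy generally cannot infer an element type for an empty `set()`. A simplified sketch of the second pattern with made-up names:

```python
from __future__ import annotations


class SplitTrackerSketch:
    def __init__(self, delta: float, grace_period: int, max_depth: int | None = None):
        self.delta = delta
        self.grace_period = grace_period
        self.max_depth = max_depth
        # An empty set() carries no element type on its own; the annotation
        # lets MyPy check later .add(...) calls against str.
        self._split_features: set[str] = set()

    def record_split(self, feature_name: str) -> None:
        self._split_features.add(feature_name)
```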
