diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md index 31c69c7ce2..bb58355b93 100644 --- a/docs/releases/unreleased.md +++ b/docs/releases/unreleased.md @@ -37,6 +37,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms ## tree - Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches. +- Fix a bug in `tree.ExtremelyFastDecisionTreeClassifier` where the split re-evaluation failed when the current branch's feature was not available as a split option. The fix also enables the tree to pre-prune a leaf via the tie-breaking mechanism. ## utils diff --git a/river/tree/extremely_fast_decision_tree.py b/river/tree/extremely_fast_decision_tree.py index 944505a65b..9b948e356b 100755 --- a/river/tree/extremely_fast_decision_tree.py +++ b/river/tree/extremely_fast_decision_tree.py @@ -451,59 +451,67 @@ def _reevaluate_best_split(self, node, parent, branch_index, **kwargs): # Manage memory self._enforce_size_limit() - - elif ( - x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau - ) and (id_current != id_best): - # Create a new branch - branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split) - leaves = tuple( - self._new_leaf(initial_stats, parent=node) - for initial_stats in x_best.children_stats - ) - - new_split = x_best.assemble(branch, node.stats, node.depth, *leaves, **kwargs) - # Update weights in new_split - new_split.last_split_reevaluation_at = node.total_weight - - n_active = n_inactive = 0 - for leaf in node.iter_leaves(): - if leaf.is_active(): - n_active += 1 + elif x_current is not None: + if ( + x_best.merit - x_current.merit > hoeffding_bound + or hoeffding_bound < self.tau + ) and (id_current != id_best): + # Create a new branch + branch = self._branch_selector( + x_best.numerical_feature, x_best.multiway_split + ) + leaves = tuple( + self._new_leaf(initial_stats, parent=node) + for initial_stats in x_best.children_stats + ) + + new_split = x_best.assemble( + branch, node.stats, node.depth, *leaves, **kwargs + ) + # Update weights in new_split + new_split.last_split_reevaluation_at = node.total_weight + + n_active = n_inactive = 0 + for leaf in node.iter_leaves(): + if leaf.is_active(): + n_active += 1 + else: + n_inactive += 1 + + self._n_active_leaves -= n_active + self._n_inactive_leaves -= n_inactive + self._n_active_leaves += len(leaves) + + if parent is None: + # Root case : replace the root node by a new split node + self._root = new_split else: - n_inactive += 1 - - self._n_active_leaves -= n_active - self._n_inactive_leaves -= n_inactive - self._n_active_leaves += len(leaves) - - if parent is None: - # Root case : replace the root node by a new split node - self._root = new_split - else: - parent.children[branch_index] = new_split - - stop_flag = True - - # Manage memory - self._enforce_size_limit() - - elif ( - x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau - ) and (id_current == id_best): - branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split) - # Change the branch but keep the existing children nodes - new_split = x_best.assemble( - branch, node.stats, node.depth, *tuple(node.children), **kwargs - ) - # Update weights in new_split - new_split.last_split_reevaluation_at = node.total_weight - - if parent is None: - # Root case : replace the root node by a new split node - self._root = new_split - else: - parent.children[branch_index] = new_split + parent.children[branch_index] = new_split + + stop_flag = True + + # Manage memory + self._enforce_size_limit() + + elif ( + x_best.merit - x_current.merit > hoeffding_bound + or hoeffding_bound < self.tau + ) and (id_current == id_best): + branch = self._branch_selector( + x_best.numerical_feature, x_best.multiway_split + ) + # Change the branch but keep the existing children nodes + new_split = x_best.assemble( + branch, node.stats, node.depth, *tuple(node.children), **kwargs + ) + # Update weights in new_split + new_split.last_split_reevaluation_at = node.total_weight + + if parent is None: + # Root case : replace the root node by a new split node + self._root = new_split + else: + parent.children[branch_index] = new_split return stop_flag @@ -551,7 +559,12 @@ def _attempt_to_split(self, node, parent, branch_index, **kwargs): node.total_weight, ) - if x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau: + if x_best.feature is None: + # Pre-pruning - null wins + node.deactivate() + self._n_inactive_leaves += 1 + self._n_active_leaves -= 1 + elif x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau: # Create a new branch branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split) leaves = tuple(