Skip to content

Commit

Permalink
Fix bug in EFDT (#1347)
Browse files Browse the repository at this point in the history
  • Loading branch information
smastelini authored Oct 24, 2023
1 parent c735492 commit ba5b19d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 53 deletions.
1 change: 1 addition & 0 deletions docs/releases/unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
## tree

- Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches.
- Fix a bug in `tree.ExtremelyFastDecisionTreeClassifier` in which split re-evaluation failed if the feature used by the current branch was not available as a split option. The fix also enables the tree to pre-prune a leaf via the tie-breaking mechanism: when the null split is the best candidate, the leaf is deactivated instead of being split.

## utils

Expand Down
119 changes: 66 additions & 53 deletions river/tree/extremely_fast_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,59 +451,67 @@ def _reevaluate_best_split(self, node, parent, branch_index, **kwargs):

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau
) and (id_current != id_best):
# Create a new branch
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
leaves = tuple(
self._new_leaf(initial_stats, parent=node)
for initial_stats in x_best.children_stats
)

new_split = x_best.assemble(branch, node.stats, node.depth, *leaves, **kwargs)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

n_active = n_inactive = 0
for leaf in node.iter_leaves():
if leaf.is_active():
n_active += 1
elif x_current is not None:
if (
x_best.merit - x_current.merit > hoeffding_bound
or hoeffding_bound < self.tau
) and (id_current != id_best):
# Create a new branch
branch = self._branch_selector(
x_best.numerical_feature, x_best.multiway_split
)
leaves = tuple(
self._new_leaf(initial_stats, parent=node)
for initial_stats in x_best.children_stats
)

new_split = x_best.assemble(
branch, node.stats, node.depth, *leaves, **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

n_active = n_inactive = 0
for leaf in node.iter_leaves():
if leaf.is_active():
n_active += 1
else:
n_inactive += 1

self._n_active_leaves -= n_active
self._n_inactive_leaves -= n_inactive
self._n_active_leaves += len(leaves)

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
n_inactive += 1

self._n_active_leaves -= n_active
self._n_inactive_leaves -= n_inactive
self._n_active_leaves += len(leaves)

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split

stop_flag = True

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau
) and (id_current == id_best):
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
# Change the branch but keep the existing children nodes
new_split = x_best.assemble(
branch, node.stats, node.depth, *tuple(node.children), **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split
parent.children[branch_index] = new_split

stop_flag = True

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound
or hoeffding_bound < self.tau
) and (id_current == id_best):
branch = self._branch_selector(
x_best.numerical_feature, x_best.multiway_split
)
# Change the branch but keep the existing children nodes
new_split = x_best.assemble(
branch, node.stats, node.depth, *tuple(node.children), **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split

return stop_flag

Expand Down Expand Up @@ -551,7 +559,12 @@ def _attempt_to_split(self, node, parent, branch_index, **kwargs):
node.total_weight,
)

if x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau:
if x_best.feature is None:
# Pre-pruning - null wins
node.deactivate()
self._n_inactive_leaves += 1
self._n_active_leaves -= 1
elif x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau:
# Create a new branch
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
leaves = tuple(
Expand Down

0 comments on commit ba5b19d

Please sign in to comment.