Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EFDT #1347

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/releases/unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
## tree

- Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches.
- Fix a bug in `tree.ExtremelyFastDecisionTreeClassifier` where the split re-evaluation failed when the current branch's feature was not available as a split option. The fix also enables the tree to pre-prune a leaf via the tie-breaking mechanism.

## utils

Expand Down
119 changes: 66 additions & 53 deletions river/tree/extremely_fast_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,59 +451,67 @@ def _reevaluate_best_split(self, node, parent, branch_index, **kwargs):

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau
) and (id_current != id_best):
# Create a new branch
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
leaves = tuple(
self._new_leaf(initial_stats, parent=node)
for initial_stats in x_best.children_stats
)

new_split = x_best.assemble(branch, node.stats, node.depth, *leaves, **kwargs)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

n_active = n_inactive = 0
for leaf in node.iter_leaves():
if leaf.is_active():
n_active += 1
elif x_current is not None:
if (
x_best.merit - x_current.merit > hoeffding_bound
or hoeffding_bound < self.tau
) and (id_current != id_best):
# Create a new branch
branch = self._branch_selector(
x_best.numerical_feature, x_best.multiway_split
)
leaves = tuple(
self._new_leaf(initial_stats, parent=node)
for initial_stats in x_best.children_stats
)

new_split = x_best.assemble(
branch, node.stats, node.depth, *leaves, **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

n_active = n_inactive = 0
for leaf in node.iter_leaves():
if leaf.is_active():
n_active += 1
else:
n_inactive += 1

self._n_active_leaves -= n_active
self._n_inactive_leaves -= n_inactive
self._n_active_leaves += len(leaves)

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
n_inactive += 1

self._n_active_leaves -= n_active
self._n_inactive_leaves -= n_inactive
self._n_active_leaves += len(leaves)

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split

stop_flag = True

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound or hoeffding_bound < self.tau
) and (id_current == id_best):
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
# Change the branch but keep the existing children nodes
new_split = x_best.assemble(
branch, node.stats, node.depth, *tuple(node.children), **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split
parent.children[branch_index] = new_split

stop_flag = True

# Manage memory
self._enforce_size_limit()

elif (
x_best.merit - x_current.merit > hoeffding_bound
or hoeffding_bound < self.tau
) and (id_current == id_best):
branch = self._branch_selector(
x_best.numerical_feature, x_best.multiway_split
)
# Change the branch but keep the existing children nodes
new_split = x_best.assemble(
branch, node.stats, node.depth, *tuple(node.children), **kwargs
)
# Update weights in new_split
new_split.last_split_reevaluation_at = node.total_weight

if parent is None:
# Root case : replace the root node by a new split node
self._root = new_split
else:
parent.children[branch_index] = new_split

return stop_flag

Expand Down Expand Up @@ -551,7 +559,12 @@ def _attempt_to_split(self, node, parent, branch_index, **kwargs):
node.total_weight,
)

if x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau:
if x_best.feature is None:
# Pre-pruning - null wins
node.deactivate()
self._n_inactive_leaves += 1
self._n_active_leaves -= 1
elif x_best.merit - x_null.merit > hoeffding_bound or hoeffding_bound < self.tau:
# Create a new branch
branch = self._branch_selector(x_best.numerical_feature, x_best.multiway_split)
leaves = tuple(
Expand Down