Skip to content

Commit

Permalink
Merge pull request #168 from mcschmitz/Fix-GreedyRuleListClassifier-bug-raised-in-csinva/imodels#167
Browse files Browse the repository at this point in the history

Fix greedy rule list classifier bug raised in #167
  • Loading branch information
csinva authored Mar 12, 2023
2 parents f3d6744 + 086c50e commit 1ac1c51
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
10 changes: 6 additions & 4 deletions imodels/rule_list/greedy_rule_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
import math
from copy import deepcopy

import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.tree import DecisionTreeClassifier
Expand Down Expand Up @@ -101,7 +99,7 @@ def fit_node_recursive(self, X, y, depth: int, verbose):
'col': self.feature_names_[col],
'index_col': col,
'cutoff': cutoff,
'val': np.mean(y), # values before splitting
'val': np.mean(y_left), # will be the values before splitting in the next lower level
'flip': flip,
'val_right': np.mean(y_right),
'num_pts': y.size,
Expand All @@ -128,7 +126,11 @@ def predict_proba(self, X):
for j, rule in enumerate(self.rules_):
if j == len(self.rules_) - 1:
probs[i] = rule['val']
elif x[rule['index_col']] >= rule['cutoff']:
continue
regular_condition = x[rule["index_col"]] >= rule["cutoff"]
flipped_condition = x[rule["index_col"]] < rule["cutoff"]
condition = flipped_condition if rule["flip"] else regular_condition
if condition:
probs[i] = rule['val_right']
break
return np.vstack((1 - probs, probs)).transpose() # probs (n, 2)
Expand Down
29 changes: 23 additions & 6 deletions tests/grl_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import unittest
import traceback

import numpy as np
from sklearn.metrics import accuracy_score
from imodels.rule_list.greedy_rule_list import GreedyRuleListClassifier
import sklearn
from sklearn.model_selection import train_test_split

class TestGRL(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.m = GreedyRuleListClassifier()

def test_integration_stability(self):
'''Test on synthetic dataset
Expand All @@ -22,11 +23,27 @@ def test_integration_stability(self):
[0, 1, 1, 1, 1],
[1, 0, 1, 1, 1]])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
m = GreedyRuleListClassifier()
m.fit(X, y)
yhat = m.predict(X)
self.m.fit(X, y)
yhat = self.m.predict(X)
acc = np.mean(y == yhat) * 100
assert acc > 99, 'acc must be 100'
assert acc > 99 # acc must be 100

def test_linear_separability(self):
"""Test if the model can learn a linearly separable dataset"""
x = np.array([0.8, 0.8, 0.3, 0.3, 0.3, 0.3]).reshape(-1, 1)
y = np.array([0, 0, 1, 1, 1, 1])
self.m.fit(x, y, verbose=True)
yhat = self.m.predict(x)
acc = np.mean(y == yhat) * 100
assert len(self.m.rules_) == 2
assert acc == 100 # acc must be 100

def test_y_left_conditional_probability(self):
"""Test conditional probability of y given x in the left node"""
x = np.array([0.8, 0.8, 0.3, 0.3, 0.3, 0.3]).reshape(-1, 1)
y = np.array([0, 0, 1, 1, 1, 1])
self.m.fit(x, y, verbose=True)
assert self.m.rules_[1]["val"] == 0

def test_breast_cancer():
np.random.seed(13)
Expand Down

0 comments on commit 1ac1c51

Please sign in to comment.