From db2ae5d8960b411d476e0d997beff6450a15eabd Mon Sep 17 00:00:00 2001 From: Hubert Baniecki Date: Wed, 28 Feb 2024 21:15:02 +0100 Subject: [PATCH] [python] try to fix gh-actions for windows --- .../dalex/test/test_arena_classification.py | 29 +++++++++++++++++-- python/dalex/test/test_ceteris_paribus.py | 2 +- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/dalex/test/test_arena_classification.py b/python/dalex/test/test_arena_classification.py index 394cc2c9..c361dd36 100644 --- a/python/dalex/test/test_arena_classification.py +++ b/python/dalex/test/test_arena_classification.py @@ -22,7 +22,8 @@ def setUp(self): self.X = data.drop(columns='survived') self.y = data.survived - self.john = pd.DataFrame({'gender': ['male'], 'age': [25], 'class': ['1st'], 'embarked': ['Southampton'], 'fare': [72], 'sibsp': [0], 'parch': 0}, index = ['John']) + self.john = pd.DataFrame({'gender': ['male'], 'age': [25], 'class': ['1st'], 'embarked': ['Southampton'], + 'fare': [72], 'sibsp': [0], 'parch': 0}, index = ['John']) numeric_features = ['age', 'fare', 'sibsp', 'parch'] numeric_transformer = Pipeline(steps=[ @@ -66,6 +67,11 @@ def test_supported_plots(self): sorting = lambda x: x.__name__ self.assertEqual(sorted(plots, key=sorting), sorted(self.reference_plots, key=sorting)) + try: + arena.stop_server() + except Exception: + pass + def test_server(self): arena = dx.Arena() arena.push_model(self.exp) @@ -78,7 +84,11 @@ def test_server(self): arena.stop_server() except AssertionError as e: arena.stop_server() - raise e + + try: + arena.stop_server() + except Exception: + pass def test_plots(self): arena = dx.Arena() @@ -91,6 +101,11 @@ def test_plots(self): count = np.sum([1 for plot in arena.plots_manager.cache if plot.__class__ == p]) self.assertEqual(np.prod(ref_counts), count, msg="Count of " + str(p)) + try: + arena.stop_server() + except Exception: + pass + def test_observation_attributes(self): arena = dx.Arena() arena.push_model(self.exp) @@ -103,6 +118,11 @@ def test_observation_attributes(self): for attr in attrs: self.assertTrue(all(attr.get('values')[:-1] == titanic[attr.get('name')])) + try: + arena.stop_server() + except Exception: + pass + def test_variable_attributes(self): arena = dx.Arena() arena.push_model(self.exp) @@ -119,5 +139,10 @@ def test_variable_attributes(self): 'sibsp': {'type': 'numeric', 'min': 0, 'max': 8, 'levels': [0, 1, 2, 3, 4, 5, 8]} }) + try: + arena.stop_server() + except Exception: + pass + if __name__ == '__main__': unittest.main() diff --git a/python/dalex/test/test_ceteris_paribus.py b/python/dalex/test/test_ceteris_paribus.py index 68fc5801..eb0e84ac 100644 --- a/python/dalex/test/test_ceteris_paribus.py +++ b/python/dalex/test/test_ceteris_paribus.py @@ -184,7 +184,7 @@ def test_plot(self): fig3 = case1.plot(case2, variables="age", show=False) fig4 = case2.plot(variables="gender", show=False) fig5 = case1.plot(case3, size=1, color="gender", facet_ncol=1, show_observations=False, - title="title", horizontal_spacing=0.2, vertical_spacing=0.2, + title="title", horizontal_spacing=0.2, vertical_spacing=0.15, show=False) fig6 = case2.plot(variables=["gender"], show=False) fig7 = case2.plot(variables=["gender", 'class'], show=False)