diff --git a/pyproject.toml b/pyproject.toml index 54d1914..448df7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "time-robust-forest" -version = "0.1.9" +version = "0.1.10" description = "Explores time information to train a robust random forest" readme = "README.md" authors = [ diff --git a/time_robust_forest/models.py b/time_robust_forest/models.py index 53d5ea7..12cc4f4 100644 --- a/time_robust_forest/models.py +++ b/time_robust_forest/models.py @@ -431,6 +431,7 @@ def __init__( split_verbose=False, impurity_verbose=False, random_state=42, + rng=None, ): if len(row_indexes) == 0: row_indexes = np.arange(len(y)) @@ -470,7 +471,11 @@ def __init__( self.period_criterion = period_criterion self.min_impurity_decrease = min_impurity_decrease self.total_sample = total_sample - self.rng = default_rng(random_state) + self.random_state = random_state + if rng == None: + self.rng = default_rng(self.random_state) + else: + self.rng = rng if sample_weight is not None: self.sample_weight = sample_weight @@ -555,6 +560,8 @@ def create_split(self): verbose=self.verbose, split_verbose=self.split_verbose, impurity_verbose=self.impurity_verbose, + random_state=self.random_state, + rng=self.rng, ) self.right_split = _RandomTimeSplitTree( self.X, @@ -574,6 +581,8 @@ def create_split(self): verbose=self.verbose, split_verbose=self.split_verbose, impurity_verbose=self.impurity_verbose, + random_state=self.random_state, + rng=self.rng, ) def find_better_split(self, variable, variable_idx):