diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 23e422ab..1bb34355 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -3,13 +3,13 @@ import abc import hashlib import logging -import pickle - import networkx as nx import numpy as np +import pickle from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import _check_sample_weight try: @@ -215,7 +215,10 @@ def _disambiguate(self): child = str(self.y_[i, j]) row.append(parent + self.separator_ + child) new_y.append(np.asarray(row, dtype=np.str_)) - self.y_ = np.array(new_y) + new_y = np.array(new_y) + self.label_encoder_ = LabelEncoder() + self.label_encoder_.fit(new_y) + self.y_ = self.label_encoder_.transform(new_y) def _create_digraph(self): # Create DiGraph diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 2420ac36..907e61cf 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -5,13 +5,13 @@ """ import hashlib -import numpy as np import pickle from copy import deepcopy -from joblib import Parallel, delayed from os.path import exists + +import numpy as np +from joblib import Parallel, delayed from sklearn.base import BaseEstimator -from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -273,9 +273,6 @@ def _fit_classifier(self, level, separator): classifier = ConstantClassifier() if not self.bert: try: - label_encoder = LabelEncoder() - label_encoder.fit(y) - y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 77d674a5..47f77475 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,13 +5,13 @@ """ import hashlib -import networkx as nx -import numpy as np import pickle from copy import deepcopy from os.path import exists + +import networkx as nx +import numpy as np from sklearn.base import BaseEstimator -from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -231,9 +231,6 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: - label_encoder = LabelEncoder() - label_encoder.fit(y) - y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y)