From 2a6d48b62aeb3ed932fba785c8265cdec2a1387f Mon Sep 17 00:00:00 2001 From: harrisonfloam Date: Wed, 27 Mar 2024 13:21:18 -0400 Subject: [PATCH] add _validate_input method to handle creation of classes_ attribute on fit call --- mlrose_hiive/neural/_nn_core.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlrose_hiive/neural/_nn_core.py b/mlrose_hiive/neural/_nn_core.py index e66f02d5..8d23cc0d 100644 --- a/mlrose_hiive/neural/_nn_core.py +++ b/mlrose_hiive/neural/_nn_core.py @@ -6,6 +6,7 @@ import numpy as np from abc import abstractmethod +from sklearn.preprocessing import LabelBinarizer from mlrose_hiive.algorithms.decay import GeomDecay from mlrose_hiive.algorithms.rhc import random_hill_climb from mlrose_hiive.algorithms.sa import simulated_annealing @@ -107,6 +108,20 @@ def _validate(self): raise Exception("""Algorithm must be one of: 'random_hill_climb', 'simulated_annealing', 'genetic_alg', 'gradient_descent'.""") + + def _validate_input(self, X, y): + """ + Fit a LabelBinarizer on y and store its classes as classes_ (binarizer kept as _label_binarizer); skipped when classes_ already exists, so later calls keep the first value. + """ + + # Required for scikit-learn 1.3+. Doesn't cause issues for lower versions. + # Copied from https://github.com/scikit-learn/scikit-learn/blob/5c4aa5d0d90ba66247d675d4c3fc2fdfba3c39ff/sklearn/neural_network/_multilayer_perceptron.py + # Note: no workaround found for multi-class labels, still doesn't work with f1 score. + + if (not hasattr(self, "classes_")): + self._label_binarizer = LabelBinarizer() + self._label_binarizer.fit(y) + self.classes_ = self._label_binarizer.classes_ def fit(self, X, y=None, init_weights=None): """Fit neural network to data. @@ -126,6 +141,7 @@ def fit(self, X, y=None, init_weights=None): If :code:`None`, then a random state is used. """ self._validate() + self._validate_input(X, y) X, y = self._format_x_y_data(X, y)