Skip to content

Commit

Permalink
add _validate_input method to handle creation of classes_ attribute o…
Browse files Browse the repository at this point in the history
…n fit call
  • Loading branch information
harrisonfloam committed Mar 27, 2024
1 parent c23243b commit 2a6d48b
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mlrose_hiive/neural/_nn_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from abc import abstractmethod
from sklearn.preprocessing import LabelBinarizer
from mlrose_hiive.algorithms.decay import GeomDecay
from mlrose_hiive.algorithms.rhc import random_hill_climb
from mlrose_hiive.algorithms.sa import simulated_annealing
Expand Down Expand Up @@ -107,6 +108,20 @@ def _validate(self):
raise Exception("""Algorithm must be one of: 'random_hill_climb',
'simulated_annealing', 'genetic_alg',
'gradient_descent'.""")

def _validate_input(self, X, y):
"""
Add _classes attribute based on classes present in y.
"""

# Required for sk-learn 1.3+. Doesn't cause issues for lower versions.
# Copied from https://github.com/scikit-learn/scikit-learn/blob/5c4aa5d0d90ba66247d675d4c3fc2fdfba3c39ff/sklearn/neural_network/_multilayer_perceptron.py
# Note: no workaround found for multi-class labels, still doesn't work with f1 score.

if (not hasattr(self, "classes_")):
self._label_binarizer = LabelBinarizer()
self._label_binarizer.fit(y)
self.classes_ = self._label_binarizer.classes_

def fit(self, X, y=None, init_weights=None):
"""Fit neural network to data.
Expand All @@ -126,6 +141,7 @@ def fit(self, X, y=None, init_weights=None):
If :code:`None`, then a random state is used.
"""
self._validate()
self._validate_input(X, y)

X, y = self._format_x_y_data(X, y)

Expand Down

0 comments on commit 2a6d48b

Please sign in to comment.