Skip to content

Commit

Permalink
Using check_array to be compatible with sklearn 0.16.1+
Browse files Browse the repository at this point in the history
see issue ejlb#2
  • Loading branch information
lkugler authored Mar 21, 2019
1 parent 2e32d43 commit 7009387
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions pegasos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@
limitations under the License.
"""


import warnings
from abc import ABCMeta, abstractmethod

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import atleast2d_or_csr
try:
from sklearn.utils import check_array
except ImportError as e:
warnings.warn(str(e)+'\n'
+ '`from sklearn.utils import check_array` failed.\n'
+ 'Using `from sklearn.utils import atleast2d_or_csr` instead.')
from sklearn.utils import atleast2d_or_csr as check_array
from scipy import sparse

from . import pegasos, constants
from .weight_vector import WeightVector

import numpy as np
import warnings

class PegasosBase(BaseEstimator, ClassifierMixin):
__metaclass__ = ABCMeta
Expand Down Expand Up @@ -62,7 +68,7 @@ def fit(self, X, y):
# training algorithm requires the labels to be -1 and +1.
y[y==0] = -1

X = atleast2d_or_csr(X, dtype=np.float64, order="C")
X = check_array(X, dtype=np.float64, order="C")

if X.shape[0] != y.shape[0]:
raise ValueError("X and y have incompatible shapes.\n"
Expand Down Expand Up @@ -113,4 +119,3 @@ def classes_(self):
if not hasattr(self, '_enc'):
raise ValueError('must call `fit` before `classes_`')
return self._enc.classes_

0 comments on commit 7009387

Please sign in to comment.