Skip to content

Commit

Permalink
Move parameter validation to separate method
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Oct 24, 2024
1 parent 9b0acde commit 8bc76b6
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions python/cuml/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,7 @@ class LogisticRegression(UniversalBase,
else:
self.verb_prefix = ""

@generate_docstring(X='dense_sparse')
@cuml.internals.api_base_return_any(set_output_dtype=True)
@enable_device_interop
def fit(self, X, y, sample_weight=None,
convert_dtype=True) -> "LogisticRegression":
"""
Fit the model with X and y.

"""
def _validate_params(self):
if self.penalty not in supported_penalties:
raise ValueError("`penalty` " + str(self.penalty) + " not supported.")

Expand All @@ -267,6 +259,17 @@ class LogisticRegression(UniversalBase,
msg = "l1_ratio value has to be between 0.0 and 1.0"
raise ValueError(msg.format(self.l1_ratio))

@generate_docstring(X='dense_sparse')
@cuml.internals.api_base_return_any(set_output_dtype=True)
@enable_device_interop
def fit(self, X, y, sample_weight=None,
convert_dtype=True) -> "LogisticRegression":
"""
Fit the model with X and y.

"""
self._validate_params()

l1_strength, l2_strength = self._get_qn_params()
self.solver_model = QN(
loss="sigmoid",
Expand All @@ -278,6 +281,7 @@ class LogisticRegression(UniversalBase,
tol=self.tol,
verbose=self.verbose,
handle=self.handle,
output_type=self.output_type,
)

self.n_features_in_ = X.shape[1] if X.ndim == 2 else 1
Expand Down

0 comments on commit 8bc76b6

Please sign in to comment.