[SPARK-50884][ML][PYTHON][CONNECT] Support isLargerBetter in Evaluator
### What changes were proposed in this pull request?
Support `isLargerBetter` in `Evaluator`.
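
For context, a minimal usage sketch that is not part of this diff: `isLargerBetter()` reports whether a larger metric value means a better model, and model-selection utilities such as `CrossValidator` rely on it when ranking candidates. The snippet assumes an active SparkSession (classic or Connect) so the evaluators can be constructed; the printed values follow the metric-direction rules added here.

```python
# Illustrative only -- not taken from the PR; assumes an active SparkSession.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator

reg = RegressionEvaluator(metricName="rmse")
print(reg.isLargerBetter())                        # False: lower RMSE is better
print(reg.setMetricName("r2").isLargerBetter())    # True: higher R^2 is better

mc = MulticlassClassificationEvaluator(metricName="logLoss")
print(mc.isLargerBetter())                         # False: lower log loss is better
```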

### Why are the changes needed?
For feature parity between classic PySpark and Spark Connect: the Connect client cannot delegate `isLargerBetter` to the JVM-backed evaluator, so each evaluator now answers it on the Python side.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
The newly added tests pass

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49620 from wbo4958/isLargerBetter.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
wbo4958 authored and zhengruifeng committed Jan 24, 2025
1 parent 8f66aef commit 311a4e0
Showing 2 changed files with 83 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/pyspark/ml/evaluation.py
@@ -311,6 +311,10 @@ def setParams(
kwargs = self._input_kwargs
return self._set(**kwargs)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return True


@inherit_doc
class RegressionEvaluator(
@@ -467,6 +471,10 @@ def setParams(
kwargs = self._input_kwargs
return self._set(**kwargs)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return self.getMetricName() in ["r2", "var"]


@inherit_doc
class MulticlassClassificationEvaluator(
@@ -700,6 +708,15 @@ def setParams(
kwargs = self._input_kwargs
return self._set(**kwargs)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return self.getMetricName() not in [
"weightedFalsePositiveRate",
"falsePositiveRateByLabel",
"logLoss",
"hammingLoss",
]


@inherit_doc
class MultilabelClassificationEvaluator(
@@ -843,6 +860,10 @@ def setParams(
kwargs = self._input_kwargs
return self._set(**kwargs)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return self.getMetricName() != "hammingLoss"


@inherit_doc
class ClusteringEvaluator(
@@ -1002,6 +1023,10 @@ def setWeightCol(self, value: str) -> "ClusteringEvaluator":
"""
return self._set(weightCol=value)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return True


@inherit_doc
class RankingEvaluator(
@@ -1138,6 +1163,10 @@ def setParams(
kwargs = self._input_kwargs
return self._set(**kwargs)

def isLargerBetter(self) -> bool:
"""Override this function to make it run on connect"""
return True


if __name__ == "__main__":
import doctest
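The flag above matters mainly during model selection. Below is a hedged sketch, mirroring (not quoting) the selection logic in `pyspark.ml.tuning`, of how a tuner would consume these overrides; the helper name `pick_best_index` is illustrative.

```python
import numpy as np

def pick_best_index(avg_metrics, evaluator):
    """Return the index of the best candidate given averaged metric values."""
    # Larger-is-better metrics (e.g. r2, areaUnderROC) take the max;
    # smaller-is-better metrics (e.g. rmse, logLoss) take the min.
    if evaluator.isLargerBetter():
        return int(np.argmax(avg_metrics))
    return int(np.argmin(avg_metrics))
```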
54 changes: 54 additions & 0 deletions python/pyspark/ml/tests/test_evaluation.py
@@ -42,6 +42,7 @@ def test_ranking_evaluator(self):

# Initialize RankingEvaluator
evaluator = RankingEvaluator().setPredictionCol("prediction")
self.assertTrue(evaluator.isLargerBetter())

# Evaluate the dataset using the default metric (mean average precision)
mean_average_precision = evaluator.evaluate(dataset)
@@ -94,6 +95,25 @@ def test_multilabel_classification_evaluator(self):
self.assertEqual(evaluator2.getPredictionCol(), "prediction")
self.assertEqual(str(evaluator), str(evaluator2))

for metric in [
"subsetAccuracy",
"accuracy",
"precision",
"recall",
"f1Measure",
"precisionByLabel",
"recallByLabel",
"f1MeasureByLabel",
"microPrecision",
"microRecall",
"microF1Measure",
]:
evaluator.setMetricName(metric)
self.assertTrue(evaluator.isLargerBetter())

evaluator.setMetricName("hammingLoss")
self.assertTrue(not evaluator.isLargerBetter())

def test_multiclass_classification_evaluator(self):
dataset = self.spark.createDataFrame(
[
@@ -163,6 +183,29 @@ def test_multiclass_classification_evaluator(self):
log_loss = evaluator.evaluate(dataset)
self.assertTrue(np.allclose(log_loss, 1.0093, atol=1e-4))

for metric in [
"f1",
"accuracy",
"weightedPrecision",
"weightedRecall",
"weightedTruePositiveRate",
"weightedFMeasure",
"truePositiveRateByLabel",
"precisionByLabel",
"recallByLabel",
"fMeasureByLabel",
]:
evaluator.setMetricName(metric)
self.assertTrue(evaluator.isLargerBetter())
for metric in [
"weightedFalsePositiveRate",
"falsePositiveRateByLabel",
"logLoss",
"hammingLoss",
]:
evaluator.setMetricName(metric)
self.assertTrue(not evaluator.isLargerBetter())

def test_binary_classification_evaluator(self):
# Define score and labels data
data = map(
@@ -180,6 +223,8 @@ def test_binary_classification_evaluator(self):
dataset = self.spark.createDataFrame(data, ["raw", "label", "weight"])

evaluator = BinaryClassificationEvaluator().setRawPredictionCol("raw")
self.assertTrue(evaluator.isLargerBetter())

auc_roc = evaluator.evaluate(dataset)
self.assertTrue(np.allclose(auc_roc, 0.7083, atol=1e-4))

@@ -226,6 +271,8 @@ def test_clustering_evaluator(self):
dataset = self.spark.createDataFrame(data, ["features", "prediction", "weight"])

evaluator = ClusteringEvaluator().setPredictionCol("prediction")
self.assertTrue(evaluator.isLargerBetter())

score = evaluator.evaluate(dataset)
self.assertTrue(np.allclose(score, 0.9079, atol=1e-4))

@@ -300,6 +347,13 @@ def test_regression_evaluator(self):
through_origin = evaluator_with_weights.getThroughOrigin()
self.assertEqual(through_origin, False)

for metric in ["mse", "rmse", "mae"]:
evaluator.setMetricName(metric)
self.assertTrue(not evaluator.isLargerBetter())
for metric in ["r2", "var"]:
evaluator.setMetricName(metric)
self.assertTrue(evaluator.isLargerBetter())


class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
