Skip to content

Commit

Permalink
[SPARK-50931][ML][PYTHON][CONNECT] Support Binarizer on connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support Binarizer on connect

### Why are the changes needed?
Feature parity: bring the Spark Connect ML client up to date with Spark Classic by supporting the Binarizer transformer.

### Does this PR introduce _any_ user-facing change?
Yes — a new algorithm (`Binarizer`) becomes available on Spark Connect.

### How was this patch tested?
Added a unit test (`test_binarizer` now runs under Spark Connect instead of being skipped).

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49637 from zhengruifeng/ml_connect_binarizer.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jan 24, 2025
1 parent 3377962 commit 6955bd5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# So register the supported transformer here if you're trying to add a new one.
########### Transformers
org.apache.spark.ml.feature.DCT
org.apache.spark.ml.feature.Binarizer
org.apache.spark.ml.feature.VectorAssembler
org.apache.spark.ml.feature.Tokenizer
org.apache.spark.ml.feature.RegexTokenizer
Expand Down
4 changes: 0 additions & 4 deletions python/pyspark/ml/tests/connect/test_parity_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@


class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
@unittest.skip("Need to support.")
def test_binarizer(self):
super().test_binarizer()

@unittest.skip("Need to support.")
def test_idf(self):
super().test_idf()
Expand Down
40 changes: 40 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,46 @@ def test_binarizer(self):
self.assertEqual(b1.getInputCol(), "input")
self.assertEqual(b1.getOutputCol(), "output")

df = self.spark.createDataFrame(
[
(0.1, 0.0),
(0.4, 1.0),
(1.2, 1.3),
(1.5, float("nan")),
(float("nan"), 1.0),
(float("nan"), 0.0),
],
["v1", "v2"],
)

bucketizer = Binarizer(threshold=1.0, inputCol="v1", outputCol="f1")
output = bucketizer.transform(df)
self.assertEqual(output.columns, ["v1", "v2", "f1"])
self.assertEqual(output.count(), 6)
self.assertEqual(
[r.f1 for r in output.select("f1").collect()],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
)

bucketizer = Binarizer(threshold=1.0, inputCols=["v1", "v2"], outputCols=["f1", "f2"])
output = bucketizer.transform(df)
self.assertEqual(output.columns, ["v1", "v2", "f1", "f2"])
self.assertEqual(output.count(), 6)
self.assertEqual(
[r.f1 for r in output.select("f1").collect()],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
)
self.assertEqual(
[r.f2 for r in output.select("f2").collect()],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
)

# save & load
with tempfile.TemporaryDirectory(prefix="binarizer") as d:
bucketizer.write().overwrite().save(d)
bucketizer2 = Binarizer.load(d)
self.assertEqual(str(bucketizer), str(bucketizer2))

def test_idf(self):
dataset = self.spark.createDataFrame(
[(DenseVector([1.0, 2.0]),), (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)],
Expand Down

0 comments on commit 6955bd5

Please sign in to comment.