[SPARK-50935][ML][PYTHON][CONNECT] Support DCT on connect
### What changes were proposed in this pull request?
Support the `DCT` feature transformer on Spark Connect.

### Why are the changes needed?
For feature parity between Spark Connect ML and Spark Classic ML.

### Does this PR introduce _any_ user-facing change?
Yes, the `DCT` transformer can now be used on Spark Connect.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49626 from zhengruifeng/ml_connect_dct.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jan 24, 2025
1 parent 5c26684 commit 2d74c3d
Showing 2 changed files with 30 additions and 0 deletions.
@@ -18,6 +18,7 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
# So register the supported transformer here if you're trying to add a new one.
########### Transformers
org.apache.spark.ml.feature.DCT
org.apache.spark.ml.feature.VectorAssembler
org.apache.spark.ml.feature.Tokenizer
org.apache.spark.ml.feature.RegexTokenizer
29 changes: 29 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
@@ -23,6 +23,7 @@
import numpy as np

from pyspark.ml.feature import (
    DCT,
    Binarizer,
    CountVectorizer,
    CountVectorizerModel,
@@ -59,6 +60,34 @@


class FeatureTestsMixin:
    def test_dct(self):
        df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
        dct = DCT()
        dct.setInverse(False)
        dct.setInputCol("vec")
        dct.setOutputCol("resultVec")

        self.assertFalse(dct.getInverse())
        self.assertEqual(dct.getInputCol(), "vec")
        self.assertEqual(dct.getOutputCol(), "resultVec")

        output = dct.transform(df)
        self.assertEqual(output.columns, ["vec", "resultVec"])
        self.assertEqual(output.count(), 1)
        self.assertTrue(
            np.allclose(
                output.head().resultVec.toArray(),
                [10.96965511, -0.70710678, -2.04124145],
                atol=1e-4,
            )
        )

        # save & load
        with tempfile.TemporaryDirectory(prefix="dct") as d:
            dct.write().overwrite().save(d)
            dct2 = DCT.load(d)
            self.assertEqual(str(dct), str(dct2))

    def test_string_indexer(self):
        df = (
            self.spark.createDataFrame(
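The expected vector asserted in `test_dct` comes from the orthonormal (scaled) DCT-II, which is the variant Spark ML's `DCT` documents applying to each input vector. As a sanity check independent of Spark, here is a minimal numpy sketch of that transform; the helper name `dct2_ortho` is made up for illustration:

```python
import numpy as np


def dct2_ortho(x):
    """Orthonormal DCT-II of a 1-D vector (the scaling Spark ML's DCT uses)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    k = np.arange(n)[:, None]  # output (frequency) index, as a column
    i = np.arange(n)[None, :]  # input (sample) index, as a row
    # Cosine basis: cos(pi * (2i + 1) * k / (2n)), an n x n matrix
    basis = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    out = np.sqrt(2.0 / n) * (basis @ x)
    out[0] /= np.sqrt(2.0)  # extra 1/sqrt(2) on the DC term for orthonormality
    return out


result = dct2_ortho([5.0, 8.0, 6.0])
print(result)  # close to [10.96965511, -0.70710678, -2.04124145]
```

Applying it to the test's input `[5.0, 8.0, 6.0]` reproduces `[10.96965511, -0.70710678, -2.04124145]` to the `atol=1e-4` tolerance used above.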
