[SPARK-51004][ML][PYTHON][CONNECT] Add support for IndexToString
### What changes were proposed in this pull request?
This PR adds support for `IndexToString` and adds `labels`/`labelsArray` to the `ALLOWED_LIST`.
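
A minimal round-trip sketch of the transformer being enabled here (assuming an active `spark` session; the data and column names are illustrative):

```python
from pyspark.ml.feature import StringIndexer, IndexToString

df = spark.createDataFrame([(0, "a"), (1, "b"), (2, "c")], ["id", "label"])
indexer = StringIndexer(inputCol="label", outputCol="labelIndex").fit(df)
# Map the numeric indices back to the original string labels.
idx2str = IndexToString(
    inputCol="labelIndex", outputCol="origLabel", labels=indexer.labels
)
idx2str.transform(indexer.transform(df)).show()
```

With this change, the same code should also work when `spark` is a Spark Connect session.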

### Why are the changes needed?

Feature parity for Spark Connect, plus a bug fix: `StringIndexerModel.labels` and `labelsArray` were missing from the `ALLOWED_LIST`, so Connect clients could not read them.

### Does this PR introduce _any_ user-facing change?
Yes, `IndexToString` and the `StringIndexerModel.labels`/`labelsArray` attributes are now supported when using Spark Connect.

### How was this patch tested?
CI passes, including the new `test_index_string` unit test.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49690 from wbo4958/index-str.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
wbo4958 authored and zhengruifeng committed Jan 27, 2025
1 parent b0e18ba commit b5deb8d
Showing 3 changed files with 50 additions and 0 deletions.
@@ -32,6 +32,7 @@
org.apache.spark.ml.feature.SQLTransformer
org.apache.spark.ml.feature.StopWordsRemover
org.apache.spark.ml.feature.FeatureHasher
org.apache.spark.ml.feature.HashingTF
org.apache.spark.ml.feature.IndexToString

########### Model for loading
# classification
python/pyspark/ml/tests/test_feature.py (48 additions, 0 deletions)
@@ -72,6 +72,7 @@
    BucketedRandomProjectionLSHModel,
    MinHashLSH,
    MinHashLSHModel,
    IndexToString,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
@@ -80,6 +81,51 @@


class FeatureTestsMixin:
    def test_index_string(self):
        dataset = self.spark.createDataFrame(
            [
                (0, "a"),
                (1, "b"),
                (2, "c"),
                (3, "a"),
                (4, "a"),
                (5, "c"),
            ],
            ["id", "label"],
        )

        indexer = StringIndexer(inputCol="label", outputCol="labelIndex").fit(dataset)
        transformed = indexer.transform(dataset)
        idx2str = (
            IndexToString()
            .setInputCol("labelIndex")
            .setOutputCol("sameLabel")
            .setLabels(indexer.labels)
        )

        def check(t: IndexToString) -> None:
            self.assertEqual(t.getInputCol(), "labelIndex")
            self.assertEqual(t.getOutputCol(), "sameLabel")
            self.assertEqual(t.getLabels(), indexer.labels)

        check(idx2str)

        ret = idx2str.transform(transformed)
        self.assertEqual(
            sorted(ret.schema.names), sorted(["id", "label", "labelIndex", "sameLabel"])
        )

        rows = ret.select("label", "sameLabel").collect()
        for r in rows:
            self.assertEqual(r.label, r.sameLabel)

        # save & load
        with tempfile.TemporaryDirectory(prefix="index_string") as d:
            idx2str.write().overwrite().save(d)
            idx2str2 = IndexToString.load(d)
            self.assertEqual(str(idx2str), str(idx2str2))
            check(idx2str2)

    def test_dct(self):
        df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
        dct = DCT()
@@ -128,6 +174,7 @@ def test_string_indexer(self):
        si = StringIndexer(inputCol="label1", outputCol="index1")
        model = si.fit(df.select("label1"))
        self.assertEqual(si.uid, model.uid)
        self.assertEqual(model.labels, list(model.labelsArray[0]))

        # read/write
        with tempfile.TemporaryDirectory(prefix="string_indexer") as tmp_dir:
@@ -188,6 +235,7 @@ def test_pca(self):
        pca = PCA(k=2, inputCol="features", outputCol="pca_features")

        model = pca.fit(df)
        self.assertTrue(np.allclose(model.pc.toArray()[0], [-0.44859172, -0.28423808], atol=1e-4))
        self.assertEqual(pca.uid, model.uid)
        self.assertEqual(model.getK(), 2)
        self.assertTrue(
@@ -646,6 +646,7 @@ private[ml] object MLUtils {
    (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")),
    (classOf[CountVectorizerModel], Set("vocabulary")),
    (classOf[OneHotEncoderModel], Set("categorySizes")),
    (classOf[StringIndexerModel], Set("labels", "labelsArray")),
    (classOf[IDFModel], Set("idf", "docFreq", "numDocs")))

  private def validate(obj: Any, method: String): Unit = {
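
For context, the `ALLOWED_LIST`-style registration above is what lets a Connect client read attributes of a server-side model. A sketch of what the new `StringIndexerModel` entry unblocks (values are illustrative, and `df` is assumed to have a string `label` column):

```python
model = StringIndexer(inputCol="label", outputCol="labelIndex").fit(df)
model.labels       # e.g. ['a', 'c', 'b'] -- now fetchable over Connect
model.labelsArray  # e.g. [('a', 'c', 'b')]
```

Without the entry, the server's `validate` check would reject these attribute fetches.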
