Skip to content

Commit

Permalink
[SPARK-50963][ML][PYTHON][TESTS][FOLLOW-UP] Enable a parity test
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Enable an existing test on connect, move it after `test_stop_words_remover`, and rename it.

### Why are the changes needed?
To improve test coverage of the Spark Connect code path.

### Does this PR introduce _any_ user-facing change?
no, test-only

### How was this patch tested?
parity test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49642 from zhengruifeng/ml_remover_test.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jan 24, 2025
1 parent 560dd5e commit 3377962
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 35 deletions.
4 changes: 0 additions & 4 deletions python/pyspark/ml/tests/connect/test_parity_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def test_idf(self):
def test_ngram(self):
    # Parity wrapper: re-runs the inherited NGram test against the connect backend.
    super().test_ngram()

@unittest.skip("Need to support.")
def test_stopwordsremover(self):
    # Skipped on connect until StopWordsRemover is supported there; this commit
    # removes the skip by moving/renaming the test in test_feature.py.
    super().test_stopwordsremover()

@unittest.skip("Need to support.")
def test_count_vectorizer_with_binary(self):
    # Still skipped: CountVectorizer binary-mode parity is not yet available on connect.
    super().test_count_vectorizer_with_binary()
Expand Down
62 changes: 31 additions & 31 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,37 @@ def test_stop_words_remover(self):
remover2 = StopWordsRemover.load(d)
self.assertEqual(str(remover), str(remover2))

def test_stop_words_remover_II(self):
    """Exercise StopWordsRemover end to end: default stop words, a custom
    word list, a language preset, and locale-aware matching."""
    remover = StopWordsRemover(inputCol="input", outputCol="output")
    df = self.spark.createDataFrame([Row(input=["a", "panda"])])

    # Default English stop words: "a" is filtered out, "panda" survives.
    self.assertEqual(remover.getInputCol(), "input")
    out = remover.transform(df)
    self.assertEqual(out.head().output, ["panda"])
    default_words = remover.getStopWords()
    self.assertEqual(type(default_words), list)
    self.assertTrue(isinstance(default_words[0], str))

    # Custom list: with "panda" as the stop word, only "a" remains.
    custom = ["panda"]
    remover.setStopWords(custom)
    self.assertEqual(remover.getInputCol(), "input")
    self.assertEqual(remover.getStopWords(), custom)
    out = remover.transform(df)
    self.assertEqual(out.head().output, ["a"])

    # Language preset: the Turkish defaults drop every token in this input.
    turkish = StopWordsRemover.loadDefaultStopWords("turkish")
    df = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
    remover.setStopWords(turkish)
    self.assertEqual(remover.getStopWords(), turkish)
    out = remover.transform(df)
    self.assertEqual(out.head().output, [])

    # Locale: under "tr", the stop word "BELKİ" matches the input "belki".
    upper = ["BELKİ"]
    df = self.spark.createDataFrame([Row(input=["belki"])])
    remover.setStopWords(upper).setLocale("tr")
    self.assertEqual(remover.getStopWords(), upper)
    out = remover.transform(df)
    self.assertEqual(out.head().output, [])

def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
Expand Down Expand Up @@ -570,37 +601,6 @@ def test_ngram(self):
transformedDF = ngram0.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])

def test_stopwordsremover(self):
    """Check StopWordsRemover behavior for the default word list, a custom
    list, the Turkish language preset, and the "tr" locale setting."""
    sw_remover = StopWordsRemover(inputCol="input", outputCol="output")
    input_df = self.spark.createDataFrame([Row(input=["a", "panda"])])

    # Default configuration: English stop words remove "a".
    self.assertEqual(sw_remover.getInputCol(), "input")
    transformed = sw_remover.transform(input_df)
    self.assertEqual(transformed.head().output, ["panda"])
    words = sw_remover.getStopWords()
    self.assertEqual(type(words), list)
    self.assertTrue(isinstance(words[0], str))

    # Custom word list: "panda" is now the stop word.
    words = ["panda"]
    sw_remover.setStopWords(words)
    self.assertEqual(sw_remover.getInputCol(), "input")
    self.assertEqual(sw_remover.getStopWords(), words)
    transformed = sw_remover.transform(input_df)
    self.assertEqual(transformed.head().output, ["a"])

    # Language selection: every input token is a Turkish stop word.
    words = StopWordsRemover.loadDefaultStopWords("turkish")
    input_df = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
    sw_remover.setStopWords(words)
    self.assertEqual(sw_remover.getStopWords(), words)
    transformed = sw_remover.transform(input_df)
    self.assertEqual(transformed.head().output, [])

    # Locale handling: "BELKİ" should match "belki" with locale "tr".
    words = ["BELKİ"]
    input_df = self.spark.createDataFrame([Row(input=["belki"])])
    sw_remover.setStopWords(words).setLocale("tr")
    self.assertEqual(sw_remover.getStopWords(), words)
    transformed = sw_remover.transform(input_df)
    self.assertEqual(transformed.head().output, [])

def test_count_vectorizer_with_binary(self):
dataset = self.spark.createDataFrame(
[
Expand Down

0 comments on commit 3377962

Please sign in to comment.