diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index eec895fea557f..ca728566490de 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -58,3 +58,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelector
 org.apache.spark.ml.feature.StringIndexer
 org.apache.spark.ml.feature.PCA
 org.apache.spark.ml.feature.Word2Vec
+org.apache.spark.ml.feature.CountVectorizer
+org.apache.spark.ml.feature.OneHotEncoder
+
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 491fa495c43fb..8899ed572ab62 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -62,3 +62,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelectorModel
 org.apache.spark.ml.feature.StringIndexerModel
 org.apache.spark.ml.feature.PCAModel
 org.apache.spark.ml.feature.Word2VecModel
+org.apache.spark.ml.feature.CountVectorizerModel
+org.apache.spark.ml.feature.OneHotEncoderModel
+
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 611b5c710add1..95788be6bd2bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -277,6 +277,8 @@ class CountVectorizerModel(
 
   import CountVectorizerModel._
 
+  private[ml] def this() = this(Identifiable.randomUID("cntVecModel"), Array.empty)
+
   @Since("1.5.0")
   def this(vocabulary: Array[String]) = {
     this(Identifiable.randomUID("cntVecModel"), vocabulary)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 44b8b2047681b..25bcdc9a1c293 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -234,6 +234,8 @@ class OneHotEncoderModel private[ml] (
 
   import OneHotEncoderModel._
 
+  private[ml] def this() = this(Identifiable.randomUID("oneHotEncoder"), Array.emptyIntArray)
+
   // Returns the category size for each index with `dropLast` and `handleInvalid`
   // taken into account.
   private def getConfigedCategorySizes: Array[Int] = {
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index fe7d50d26951f..a969ea34247f3 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -27,6 +27,8 @@
     Binarizer,
     CountVectorizer,
     CountVectorizerModel,
+    OneHotEncoder,
+    OneHotEncoderModel,
     HashingTF,
     IDF,
     NGram,
@@ -535,6 +537,61 @@ def test_word2vec(self):
             model2 = Word2VecModel.load(d)
             self.assertEqual(str(model), str(model2))
 
+    def test_count_vectorizer(self):
+        df = self.spark.createDataFrame(
+            [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
+            ["label", "raw"],
+        )
+
+        cv = CountVectorizer()
+        cv.setInputCol("raw")
+        cv.setOutputCol("vectors")
+        self.assertEqual(cv.getInputCol(), "raw")
+        self.assertEqual(cv.getOutputCol(), "vectors")
+
+        model = cv.fit(df)
+        self.assertEqual(sorted(model.vocabulary), ["a", "b", "c"])
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["label", "raw", "vectors"])
+        self.assertEqual(output.count(), 2)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="count_vectorizer") as d:
+            cv.write().overwrite().save(d)
+            cv2 = CountVectorizer.load(d)
+            self.assertEqual(str(cv), str(cv2))
+
+            model.write().overwrite().save(d)
+            model2 = CountVectorizerModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
+    def test_one_hot_encoder(self):
+        df = self.spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
+
+        encoder = OneHotEncoder()
+        encoder.setInputCols(["input"])
+        encoder.setOutputCols(["output"])
+        self.assertEqual(encoder.getInputCols(), ["input"])
+        self.assertEqual(encoder.getOutputCols(), ["output"])
+
+        model = encoder.fit(df)
+        self.assertEqual(model.categorySizes, [3])
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["input", "output"])
+        self.assertEqual(output.count(), 3)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="one_hot_encoder") as d:
+            encoder.write().overwrite().save(d)
+            encoder2 = OneHotEncoder.load(d)
+            self.assertEqual(str(encoder), str(encoder2))
+
+            model.write().overwrite().save(d)
+            model2 = OneHotEncoderModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
     def test_tokenizer(self):
         df = self.spark.createDataFrame([("a b c",)], ["text"])
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index a4a8350ef4209..4c45c14853cbe 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -591,7 +591,9 @@ private[ml] object MLUtils {
       (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
       (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
       (classOf[PCAModel], Set("pc", "explainedVariance")),
-      (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")))
+      (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")),
+      (classOf[CountVectorizerModel], Set("vocabulary")),
+      (classOf[OneHotEncoderModel], Set("categorySizes")))
 
   private def validate(obj: Any, method: String): Unit = {
     assert(obj != null)
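For context, here is a minimal sketch of what this patch enables from a Spark Connect client: the two newly registered estimators can be fit remotely, and the model attributes now allowlisted in MLUtils (`vocabulary`, `categorySizes`) can be read back. The remote URL, port, and variable names below are illustrative assumptions, not part of this patch, and a running Connect server is presumed.

```python
# Sketch only: assumes a Spark Connect server is reachable at sc://localhost:15002.
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizer, OneHotEncoder

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# CountVectorizerModel is now constructible server-side via the ServiceLoader
# registration and the new no-arg constructor, so fit() works over Connect.
df = spark.createDataFrame(
    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ["label", "raw"]
)
cv_model = CountVectorizer(inputCol="raw", outputCol="vectors").fit(df)
print(sorted(cv_model.vocabulary))  # attribute allowlisted in MLUtils

# Likewise for OneHotEncoderModel and its categorySizes attribute.
df2 = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
ohe_model = OneHotEncoder(inputCols=["input"], outputCols=["output"]).fit(df2)
print(ohe_model.categorySizes)
```

The `private[ml]` no-arg constructors exist, presumably, so the Connect server can instantiate the model classes reflectively before filling in their params and state; clients never call them directly.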