Skip to content

Commit

Permalink
[SPARK-50919][ML][PYTHON][CONNECT] Support LinearSVC on connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add support for `LinearSVC` (the estimator and its `LinearSVCModel`) on Spark Connect.

### Why are the changes needed?
For feature parity between Spark Connect and the existing (non-Connect) ML API.

### Does this PR introduce _any_ user-facing change?
Yes: `LinearSVC` can now be used on Spark Connect.

### How was this patch tested?
Added a new test, `test_linear_svc`, in `python/pyspark/ml/tests/test_classification.py`.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49625 from zhengruifeng/ml_connect_svc.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jan 24, 2025
1 parent 42b15c9 commit 5c26684
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ sealed trait Vector extends Serializable {
@Since("2.0.0")
object Vectors {

private[ml] val empty: Vector = zeros(0)
private[ml] val empty: DenseVector = new DenseVector(Array.emptyDoubleArray)

/**
* Creates a dense vector from its values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# So register the supported estimator here if you're trying to add a new one.

# classification
org.apache.spark.ml.classification.LinearSVC
org.apache.spark.ml.classification.LogisticRegression
org.apache.spark.ml.classification.DecisionTreeClassifier
org.apache.spark.ml.classification.RandomForestClassifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ org.apache.spark.ml.feature.StopWordsRemover

########### Model for loading
# classification
org.apache.spark.ml.classification.LinearSVCModel
org.apache.spark.ml.classification.LogisticRegressionModel
org.apache.spark.ml.classification.DecisionTreeClassificationModel
org.apache.spark.ml.classification.RandomForestClassificationModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ class LinearSVCModel private[classification] (
extends ClassificationModel[Vector, LinearSVCModel]
with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] {

private[ml] def this() = this(Identifiable.randomUID("linearsvc"), Vectors.empty, 0.0)

@Since("2.2.0")
override val numClasses: Int = 2

Expand Down
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class PCAModel private[ml] (
// For ml connect only
@Since("4.0.0")
private[ml] def this() = this(Identifiable.randomUID("pca"),
DenseMatrix.zeros(1, 1), Vectors.empty.asInstanceOf[DenseVector])
DenseMatrix.zeros(1, 1), Vectors.empty)

/** @group setParam */
@Since("1.5.0")
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
HasSolver,
HasParallelism,
)
from pyspark.ml.util import try_remote_attribute_relation
from pyspark.ml.tree import (
_DecisionTreeModel,
_DecisionTreeParams,
Expand All @@ -86,6 +85,7 @@
MLWriter,
MLWritable,
HasTrainingSummary,
try_remote_attribute_relation,
)
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
Expand Down
76 changes: 76 additions & 0 deletions python/pyspark/ml/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from pyspark.ml.linalg import Vectors, Matrices
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml.classification import (
LinearSVC,
LinearSVCModel,
LinearSVCSummary,
LinearSVCTrainingSummary,
LogisticRegression,
LogisticRegressionModel,
LogisticRegressionSummary,
Expand Down Expand Up @@ -299,6 +303,78 @@ def test_logistic_regression(self):
except OSError:
pass

def test_linear_svc(self):
    # End-to-end check of LinearSVC: estimator params, fitted model values,
    # prediction, transform output, training/evaluation summaries and save/load.
    rows = [
        (1.0, 1.0, Vectors.dense(0.0, 5.0)),
        (0.0, 2.0, Vectors.dense(1.0, 2.0)),
        (1.0, 3.0, Vectors.dense(2.0, 1.0)),
        (0.0, 4.0, Vectors.dense(3.0, 3.0)),
    ]
    schema = ["label", "weight", "features"]
    # One partition, ordered by weight -- presumably so the fitted numbers
    # asserted below are reproducible across runs.
    train_df = self.spark.createDataFrame(rows, schema).coalesce(1).sortWithinPartitions("weight")

    def assert_close(actual, expected):
        # Shared tolerance for every floating-point comparison in this test.
        self.assertTrue(np.allclose(actual, expected, atol=1e-4))

    estimator = LinearSVC(maxIter=1, regParam=1.0)
    self.assertEqual(estimator.getMaxIter(), 1)
    self.assertEqual(estimator.getRegParam(), 1.0)

    # Fitted model attributes.
    model = estimator.fit(train_df)
    self.assertEqual(model.numClasses, 2)
    self.assertEqual(model.numFeatures, 2)
    assert_close(model.intercept, 0.025877458475338313)
    assert_close(model.coefficients.toArray(), [-0.03622844, 0.01035098])

    # Single-vector prediction.
    sample = Vectors.dense(0.0, 5.0)
    self.assertEqual(model.predict(sample), 1.0)
    assert_close(model.predictRaw(sample).toArray(), [-0.07763238, 0.07763238])

    # transform() appends the raw prediction and prediction columns.
    expected_cols = schema + ["rawPrediction", "prediction"]
    transformed = model.transform(train_df)
    self.assertEqual(transformed.columns, expected_cols)
    self.assertEqual(transformed.count(), 4)

    def check_metrics(result):
        # Metrics expected from both the training summary and evaluate().
        self.assertEqual(result.labels, [0.0, 1.0])
        self.assertEqual(result.accuracy, 0.5)
        self.assertEqual(result.areaUnderROC, 0.75)
        self.assertEqual(result.predictions.columns, expected_cols)

    # model summary
    self.assertTrue(model.hasSummary)
    training_summary = model.summary()
    self.assertIsInstance(training_summary, LinearSVCSummary)
    self.assertIsInstance(training_summary, LinearSVCTrainingSummary)
    check_metrics(training_summary)

    # evaluate() yields a plain summary, not a training summary.
    eval_summary = model.evaluate(train_df)
    self.assertIsInstance(eval_summary, LinearSVCSummary)
    self.assertFalse(isinstance(eval_summary, LinearSVCTrainingSummary))
    check_metrics(eval_summary)

    # Model save & load
    with tempfile.TemporaryDirectory(prefix="linear_svc") as tmp_dir:
        estimator.write().overwrite().save(tmp_dir)
        loaded_estimator = LinearSVC.load(tmp_dir)
        self.assertEqual(str(estimator), str(loaded_estimator))

        model.write().overwrite().save(tmp_dir)
        loaded_model = LinearSVCModel.load(tmp_dir)
        self.assertEqual(str(model), str(loaded_model))

def test_decision_tree_classifier(self):
df = (
self.spark.createDataFrame(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ private[ml] object MLUtils {
(classOf[GBTRegressionModel], Set("featureImportances", "evaluateEachIteration")),

// Classification Models
(classOf[LinearSVCModel], Set("intercept", "coefficients", "evaluate")),
(
classOf[LogisticRegressionModel],
Set("intercept", "coefficients", "interceptVector", "coefficientMatrix", "evaluate")),
Expand Down

0 comments on commit 5c26684

Please sign in to comment.