init
zhengruifeng committed Jan 24, 2025
1 parent 2d74c3d commit a4ceb7e
Showing 3 changed files with 36 additions and 29 deletions.
27 changes: 18 additions & 9 deletions python/pyspark/ml/connect/serialize.py
@@ -49,13 +49,23 @@ def build_float_list(value: List[float]) -> pb2.Expression.Literal:
     return p
 
 
+def build_proto_udt(jvm_class: str) -> pb2.DataType:
+    ret = pb2.DataType()
+    ret.udt.type = "udt"
+    ret.udt.jvm_class = jvm_class
+    return ret
+
+
+proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT")
+proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT")
+
+
 def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
     from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.connect.expressions import LiteralExpression
 
     if isinstance(value, SparseVector):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_vector_udt)
         # type = 0
         p.struct.elements.append(pb2.Expression.Literal(byte=0))
         # size
Expand All @@ -68,7 +78,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.

elif isinstance(value, DenseVector):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
p.struct.struct_type.CopyFrom(proto_vector_udt)
# type = 1
p.struct.elements.append(pb2.Expression.Literal(byte=1))
# size = null
Expand All @@ -81,7 +91,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.

elif isinstance(value, SparseMatrix):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
p.struct.struct_type.CopyFrom(proto_matrix_udt)
# type = 0
p.struct.elements.append(pb2.Expression.Literal(byte=0))
# numRows
Expand All @@ -100,7 +110,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.

elif isinstance(value, DenseMatrix):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
p.struct.struct_type.CopyFrom(proto_matrix_udt)
# type = 1
p.struct.elements.append(pb2.Expression.Literal(byte=1))
# numRows
Expand Down Expand Up @@ -134,14 +144,13 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:


def deserialize_param(literal: pb2.Expression.Literal) -> Any:
from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
from pyspark.sql.connect.expressions import LiteralExpression

if literal.HasField("struct"):
s = literal.struct
schema = proto_schema_to_pyspark_data_type(s.struct_type)
jvm_class = s.struct_type.udt.jvm_class

if schema == VectorUDT.sqlType():
if jvm_class == "org.apache.spark.ml.linalg.VectorUDT":
assert len(s.elements) == 4
tpe = s.elements[0].byte
if tpe == 0:
Expand All @@ -155,7 +164,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
else:
raise ValueError(f"Unknown Vector type {tpe}")

elif schema == MatrixUDT.sqlType():
elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT":
assert len(s.elements) == 7
tpe = s.elements[0].byte
if tpe == 0:
Expand Down
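Taken together, the Python change swaps schema-equality checks for a string tag: each vector/matrix struct literal is stamped once with a cached UDT proto carrying the JVM class name, and deserialization dispatches on that name instead of reconstructing and comparing a full schema. Below is a minimal, self-contained sketch of the pattern; the dataclasses are hypothetical stand-ins for the generated `pb2.DataType` and `pb2.Expression.Literal` messages, not the real protobuf API.

```python
from dataclasses import dataclass, field
from typing import Any, List

VECTOR_UDT_CLASS = "org.apache.spark.ml.linalg.VectorUDT"


@dataclass
class Udt:  # hypothetical stand-in for the udt variant of pb2.DataType
    type: str = ""
    jvm_class: str = ""


@dataclass
class StructLiteral:  # hypothetical stand-in for pb2.Expression.Literal.struct
    struct_type: Udt = field(default_factory=Udt)
    elements: List[Any] = field(default_factory=list)


def build_proto_udt(jvm_class: str) -> Udt:
    # Built once at module import, mirroring proto_vector_udt / proto_matrix_udt.
    return Udt(type="udt", jvm_class=jvm_class)


proto_vector_udt = build_proto_udt(VECTOR_UDT_CLASS)


def serialize_dense(values: List[float]) -> StructLiteral:
    # Byte tag 1 marks a dense vector, as in the diff; sparse uses tag 0.
    # Element layout follows the diff: [type, size, indices, values].
    return StructLiteral(proto_vector_udt, [1, None, None, values])


def deserialize(lit: StructLiteral) -> Any:
    # Dispatch on the JVM class string rather than comparing schemas.
    if lit.struct_type.jvm_class == VECTOR_UDT_CLASS:
        tpe = lit.elements[0]
        if tpe == 1:
            return lit.elements[3]  # dense: values live in the fourth slot
        raise ValueError(f"Unknown Vector type {tpe}")
    raise ValueError(f"Unsupported struct {lit}")


assert deserialize(serialize_dense([1.0, 2.0])) == [1.0, 2.0]
```

One practical payoff of the design: the tag protos are built once at import time, so the hot serialization path no longer converts a PySpark schema to a proto type on every call.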

MLUtils.scala
@@ -38,7 +38,7 @@ import org.apache.spark.ml.regression._
 import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable}
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -147,13 +147,11 @@ private[ml] object MLUtils {
     val value = literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val s = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(s.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          deserializeVector(s)
-        } else if (schema == MatrixUDT.sqlType) {
-          deserializeMatrix(s)
-        } else {
-          throw MlUnsupportedException(s"Unsupported parameter struct ${schema} for ${name}")
+        s.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" => deserializeVector(s)
+          case "org.apache.spark.ml.linalg.MatrixUDT" => deserializeMatrix(s)
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct} for ${name}")
         }
 
       case _ =>

Serializer.scala
@@ -21,7 +21,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter, ProtoDataTypes}
+import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoDataTypes}
 import org.apache.spark.sql.connect.service.SessionHolder
 
 private[ml] object Serializer {
@@ -37,7 +37,7 @@ private[ml] object Serializer {
     data match {
       case v: SparseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // size
@@ -50,7 +50,7 @@ private[ml] object Serializer {
 
       case v: DenseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // size = null
@@ -65,7 +65,7 @@ private[ml] object Serializer {
 
       case m: SparseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // numRows
@@ -84,7 +84,7 @@ private[ml] object Serializer {
 
       case m: DenseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // numRows
@@ -146,13 +146,13 @@ private[ml] object Serializer {
     literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val struct = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(struct.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          (MLUtils.deserializeVector(struct), classOf[Vector])
-        } else if (schema == MatrixUDT.sqlType) {
-          (MLUtils.deserializeMatrix(struct), classOf[Matrix])
-        } else {
-          throw MlUnsupportedException(s"$schema not supported")
+        struct.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" =>
+            (MLUtils.deserializeVector(struct), classOf[Vector])
+          case "org.apache.spark.ml.linalg.MatrixUDT" =>
+            (MLUtils.deserializeMatrix(struct), classOf[Matrix])
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct}")
         }
       case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
         (literal.getInteger.asInstanceOf[Object], classOf[Int])
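For a sense of when this path runs: any Vector- or Matrix-typed Param shipped over Spark Connect passes through these serializers, with the Scala side above mirroring the Python dispatch on the UDT's JVM class. A usage-level sketch, assuming a PySpark client connected via Spark Connect; `ElementwiseProduct.scalingVec` is one real Vector-typed Param that would travel as a VectorUDT-tagged struct literal:

```python
from pyspark.ml.feature import ElementwiseProduct
from pyspark.ml.linalg import Vectors

# Over Spark Connect, this Vector-valued Param is serialized as a struct
# literal tagged "org.apache.spark.ml.linalg.VectorUDT" and decoded on the
# server by the jvm_class dispatch shown above.
ep = ElementwiseProduct(
    scalingVec=Vectors.dense([2.0, 0.5, 1.0]),
    inputCol="features",
    outputCol="scaled",
)
```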
