diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py
index 62b21460feb7c..8635f22841a75 100644
--- a/python/pyspark/ml/connect/serialize.py
+++ b/python/pyspark/ml/connect/serialize.py
@@ -49,13 +49,23 @@ def build_float_list(value: List[float]) -> pb2.Expression.Literal:
     return p
 
 
+def build_proto_udt(jvm_class: str) -> pb2.DataType:
+    ret = pb2.DataType()
+    ret.udt.type = "udt"
+    ret.udt.jvm_class = jvm_class
+    return ret
+
+
+proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT")
+proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT")
+
+
 def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
-    from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.connect.expressions import LiteralExpression
 
     if isinstance(value, SparseVector):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_vector_udt)
         # type = 0
         p.struct.elements.append(pb2.Expression.Literal(byte=0))
         # size
@@ -68,7 +78,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, DenseVector):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_vector_udt)
         # type = 1
         p.struct.elements.append(pb2.Expression.Literal(byte=1))
         # size = null
@@ -81,7 +91,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, SparseMatrix):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_matrix_udt)
         # type = 0
         p.struct.elements.append(pb2.Expression.Literal(byte=0))
         # numRows
@@ -100,7 +110,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, DenseMatrix):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_matrix_udt)
         # type = 1
         p.struct.elements.append(pb2.Expression.Literal(byte=1))
         # numRows
@@ -134,14 +144,13 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
 
 
 def deserialize_param(literal: pb2.Expression.Literal) -> Any:
-    from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
     from pyspark.sql.connect.expressions import LiteralExpression
 
     if literal.HasField("struct"):
         s = literal.struct
-        schema = proto_schema_to_pyspark_data_type(s.struct_type)
+        jvm_class = s.struct_type.udt.jvm_class
 
-        if schema == VectorUDT.sqlType():
+        if jvm_class == "org.apache.spark.ml.linalg.VectorUDT":
             assert len(s.elements) == 4
             tpe = s.elements[0].byte
             if tpe == 0:
@@ -155,7 +164,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
             else:
                 raise ValueError(f"Unknown Vector type {tpe}")
 
-        elif schema == MatrixUDT.sqlType():
+        elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT":
             assert len(s.elements) == 7
             tpe = s.elements[0].byte
             if tpe == 0:
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 926fc23621634..e784dfd504d3f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -38,7 +38,7 @@ import org.apache.spark.ml.regression._
 import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable}
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -147,13 +147,11 @@ private[ml] object MLUtils {
     val value = literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val s = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(s.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          deserializeVector(s)
-        } else if (schema == MatrixUDT.sqlType) {
-          deserializeMatrix(s)
-        } else {
-          throw MlUnsupportedException(s"Unsupported parameter struct ${schema} for ${name}")
+        s.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" => deserializeVector(s)
+          case "org.apache.spark.ml.linalg.MatrixUDT" => deserializeMatrix(s)
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct} for ${name}")
         }
 
       case _ =>
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index ee0812a1a98ca..df3e97398012b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -21,7 +21,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter, ProtoDataTypes}
+import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoDataTypes}
 import org.apache.spark.sql.connect.service.SessionHolder
 
 private[ml] object Serializer {
@@ -37,7 +37,7 @@ private[ml] object Serializer {
     data match {
       case v: SparseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // size
@@ -50,7 +50,7 @@ private[ml] object Serializer {
 
       case v: DenseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // size = null
@@ -65,7 +65,7 @@ private[ml] object Serializer {
 
      case m: SparseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // numRows
@@ -84,7 +84,7 @@ private[ml] object Serializer {
 
       case m: DenseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // numRows
@@ -146,13 +146,13 @@ private[ml] object Serializer {
     literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val struct = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(struct.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          (MLUtils.deserializeVector(struct), classOf[Vector])
-        } else if (schema == MatrixUDT.sqlType) {
-          (MLUtils.deserializeMatrix(struct), classOf[Matrix])
-        } else {
-          throw MlUnsupportedException(s"$schema not supported")
+        struct.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" =>
+            (MLUtils.deserializeVector(struct), classOf[Vector])
+          case "org.apache.spark.ml.linalg.MatrixUDT" =>
+            (MLUtils.deserializeMatrix(struct), classOf[Matrix])
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct}")
         }
       case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
         (literal.getInteger.asInstanceOf[Object], classOf[Int])
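
For reference, here is a minimal sketch of the round trip this patch enables, assuming a PySpark build where `pyspark.sql.connect.proto` is importable as `pb2` (as in the patched `serialize.py`). The sender tags the struct literal with a lightweight UDT proto that carries only the JVM class name; the receiver dispatches on that string instead of reconstructing a PySpark data type and comparing it against `VectorUDT.sqlType()`:

```python
import pyspark.sql.connect.proto as pb2


# Same helper the patch adds to serialize.py: a DataType proto carrying
# only the UDT's JVM class name, not the expanded sqlType() struct schema.
def build_proto_udt(jvm_class: str) -> pb2.DataType:
    ret = pb2.DataType()
    ret.udt.type = "udt"
    ret.udt.jvm_class = jvm_class
    return ret


# Sender side: tag a struct literal as a dense vector
# (type tag byte = 1, per the serialize_param branches above).
lit = pb2.Expression.Literal()
lit.struct.struct_type.CopyFrom(build_proto_udt("org.apache.spark.ml.linalg.VectorUDT"))
lit.struct.elements.append(pb2.Expression.Literal(byte=1))

# Receiver side: a single string comparison replaces the old
# proto_schema_to_pyspark_data_type(...) == VectorUDT.sqlType() check.
assert lit.struct.struct_type.udt.jvm_class == "org.apache.spark.ml.linalg.VectorUDT"
```

The Scala side mirrors this: `Serializer` stamps the prebuilt `ProtoDataTypes.VectorUDT`/`ProtoDataTypes.MatrixUDT` messages, and both `MLUtils` and `Serializer` match on `getStructType.getUdt.getJvmClass`, so neither end needs `DataTypeProtoConverter` any longer.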