From 253c6c267184b9ce6df9ec46205d7317f6fe3af2 Mon Sep 17 00:00:00 2001
From: Norbert Schultz
Date: Thu, 18 Jan 2024 14:09:44 +0100
Subject: [PATCH] Support for tuples

---
 .../sql/CassandraDataFrameSelectUdtSpec.scala | 21 +++++++++++++++++--
 .../datasource/UnsafeRowReaderFactory.scala   | 14 ++++++++++++-
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataFrameSelectUdtSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataFrameSelectUdtSpec.scala
index f5698d092..a7162bf8e 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataFrameSelectUdtSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataFrameSelectUdtSpec.scala
@@ -28,12 +28,14 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
          | single embedded,
          | embedded_map MAP<TEXT, FROZEN<embedded>>,
          | embedded_set SET<FROZEN<embedded>>,
+         | simple_tuple TUPLE<TEXT, INT>,
+         | simple_tuples LIST<FROZEN<TUPLE<TEXT, INT>>>,
          | PRIMARY KEY (id)
          | )""".stripMargin
     )
 
     session.execute(
-      s"""INSERT INTO ${ks}.crash_test JSON '{"id": 1, "embeddeds": [], "embedded_map": {}, "embedded_set": []}'"""
+      s"""INSERT INTO ${ks}.crash_test JSON '{"id": 1, "embeddeds": [], "embedded_map": {}, "embedded_set": [], "simple_tuples": []}'"""
     )
     session.execute(
       s"""INSERT INTO ${ks}.crash_test JSON
@@ -42,7 +44,9 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
         | "single": {"a": "a1", "b": 1},
         | "embeddeds": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}],
         | "embedded_map": {"1": {"a": "x1", "b": 1}, "2": {"a": "x2", "b": 2}},
-        | "embedded_set": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}]
+        | "embedded_set": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}],
+        | "simple_tuple": ["x1", 1],
+        | "simple_tuples": [["x1", 1], ["x2", 2]]
         |}'
         |""".stripMargin
     )
@@ -95,6 +99,19 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
     elements should contain theSameElementsAs Seq(0, 3)
   }
 
+  it should "allow single elements of tuples" in new Env {
+    df.select("simple_tuple.1").collect().map { row =>
+      optionalInt(row, 0).getOrElse(0)
+    } should contain theSameElementsAs Seq(0, 1)
+  }
+
+  it should "allow selecting projections of tuples" in new Env {
+    val elements = df.select("simple_tuples.1").collect().map { row =>
+      row.getAs[Seq[Int]](0).sum
+    }
+    elements should contain theSameElementsAs Seq(0, 3)
+  }
+
   private def optionalInt(row: Row, idx: Int): Option[Int] = {
     if (row.isNullAt(idx)) {
       None
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/UnsafeRowReaderFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/UnsafeRowReaderFactory.scala
index cc33f5c67..dbfb6a069 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/UnsafeRowReaderFactory.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/UnsafeRowReaderFactory.scala
@@ -3,7 +3,7 @@ package com.datastax.spark.connector.datasource
 import com.datastax.oss.driver.api.core.cql.Row
 import com.datastax.spark.connector.cql.TableDef
 import com.datastax.spark.connector.rdd.reader.{RowReader, RowReaderFactory}
-import com.datastax.spark.connector.{CassandraRow, CassandraRowMetadata, ColumnRef, UDTValue}
+import com.datastax.spark.connector.{CassandraRow, CassandraRowMetadata, ColumnRef, TupleValue, UDTValue}
 import org.apache.spark.sql.cassandra.CassandraSQLRow.toUnsafeSqlType
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -97,6 +97,18 @@ object UdtProjectionDecoder {
           decoded.asInstanceOf[AnyRef]
         }
         UDTValue.apply(structType.fieldNames.toIndexedSeq, selectedValues.toIndexedSeq)
+      case tuple: TupleValue =>
+        val selectedValues = structType.fields.zipWithIndex.map { case (field, idx) =>
+          val fieldInt = try {
+            field.name.toInt
+          } catch {
+            case _: NumberFormatException =>
+              throw new IllegalArgumentException(s"Expected integer for tuple column name, got ${field.name}")
+          }
+          val decoded = childDecoders(idx)(tuple.values(fieldInt))
+          decoded.asInstanceOf[AnyRef]
+        }
+        TupleValue(selectedValues.toIndexedSeq: _*)
       case other =>
         // ?? other