Skip to content

Commit

Permalink
Support for tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
nob13 committed Jan 18, 2024
1 parent f9b3c96 commit 253c6c2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
| single embedded,
| embedded_map MAP<INT, FROZEN<embedded>>,
| embedded_set SET<FROZEN<embedded>>,
| simple_tuple TUPLE <TEXT, INT>,
| simple_tuples LIST<FROZEN<TUPLE<TEXT, INT>>>,
| PRIMARY KEY (id)
| )""".stripMargin
)

session.execute(
s"""INSERT INTO ${ks}.crash_test JSON '{"id": 1, "embeddeds": [], "embedded_map": {}, "embedded_set": []}'"""
s"""INSERT INTO ${ks}.crash_test JSON '{"id": 1, "embeddeds": [], "embedded_map": {}, "embedded_set": [], "simple_tuples": []}'"""
)
session.execute(
s"""INSERT INTO ${ks}.crash_test JSON
Expand All @@ -42,7 +44,9 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
| "single": {"a": "a1", "b": 1},
| "embeddeds": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}],
| "embedded_map": {"1": {"a": "x1", "b": 1}, "2": {"a": "x2", "b": 2}},
| "embedded_set": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}]
| "embedded_set": [{"a": "x1", "b": 1}, {"a": "x2", "b": 2}],
| "simple_tuple": ["x1", 1],
| "simple_tuples": [["x1", 1], ["x2", 2]]
|}'
|""".stripMargin
)
Expand Down Expand Up @@ -95,6 +99,19 @@ class CassandraDataFrameSelectUdtSpec extends SparkCassandraITFlatSpecBase with
elements should contain theSameElementsAs Seq(0, 3)
}

it should "allow single elements of tuples" in new Env {
  // Project the second tuple component (index 1, the INT); rows without a
  // tuple come back null and are normalized to 0 before comparison.
  val secondComponents = df.select("simple_tuple.1").collect().map { r =>
    optionalInt(r, 0).getOrElse(0)
  }
  secondComponents should contain theSameElementsAs Seq(0, 1)
}

it should "allow selecting projections of tuples" in new Env {
  // For each row, sum the projected second components of every tuple in the
  // list column; an empty list sums to 0, the populated row to 1 + 2 = 3.
  val sums = df.select("simple_tuples.1").collect().map { r =>
    val ints = r.getAs[Seq[Int]](0)
    ints.sum
  }
  sums should contain theSameElementsAs Seq(0, 3)
}

private def optionalInt(row: Row, idx: Int): Option[Int] = {
if (row.isNullAt(idx)) {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.datastax.spark.connector.datasource
import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.spark.connector.cql.TableDef
import com.datastax.spark.connector.rdd.reader.{RowReader, RowReaderFactory}
import com.datastax.spark.connector.{CassandraRow, CassandraRowMetadata, ColumnRef, UDTValue}
import com.datastax.spark.connector.{CassandraRow, CassandraRowMetadata, ColumnRef, TupleValue, UDTValue}
import org.apache.spark.sql.cassandra.CassandraSQLRow.toUnsafeSqlType
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
Expand Down Expand Up @@ -97,6 +97,18 @@ object UdtProjectionDecoder {
decoded.asInstanceOf[AnyRef]
}
UDTValue.apply(structType.fieldNames.toIndexedSeq, selectedValues.toIndexedSeq)
case tuple: TupleValue =>
val selectedValues = structType.fields.zipWithIndex.map { case (field, idx) =>
val fieldInt = try {
field.name.toInt
} catch {
case _: NumberFormatException =>
throw new IllegalArgumentException(s"Expected integer for tuple column name, got ${field.name}")
}
val decoded = childDecoders(idx)(tuple.values(fieldInt))
decoded.asInstanceOf[AnyRef]
}
TupleValue(selectedValues.toIndexedSeq: _*)
case other =>
// ??
other
Expand Down

0 comments on commit 253c6c2

Please sign in to comment.