SNOW-903045 Support Dataframe Contains both Binary and Variant Columns (#526)

* support dataframe contains both binary and variant columns

* fix error

* fix error
sfc-gh-bli authored Sep 18, 2023
1 parent fedb7a3 commit 4cb486c
Showing 3 changed files with 97 additions and 6 deletions.
54 changes: 54 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
@@ -23,6 +23,7 @@ class VariantTypeSuite extends IntegrationSuiteBase {
val tableName1 = s"spark_test_table_1$randomSuffix"
val tableName2 = s"spark_test_table_2$randomSuffix"
val tableName3 = s"spark_test_table_3$randomSuffix"
val tableName4 = s"spark_test_table_4$randomSuffix"
override def beforeAll(): Unit = {
super.beforeAll()

@@ -43,6 +44,7 @@ class VariantTypeSuite extends IntegrationSuiteBase {
jdbcUpdate(s"drop table if exists $tableName1")
jdbcUpdate(s"drop table if exists $tableName2")
jdbcUpdate(s"drop table if exists $tableName3")
jdbcUpdate(s"drop table if exists $tableName4")
super.afterAll()
}

@@ -211,4 +213,56 @@ class VariantTypeSuite extends IntegrationSuiteBase {
)
}

test("load variant + binary column") {
// COPY UNLOAD can't be run because it doesn't support binary
if (!params.useCopyUnload) {
val data = sc.parallelize(
Seq(
Row("binary1".getBytes(), Array(1, 2, 3), Map("a" -> 1), Row("abc")),
Row("binary2".getBytes(), Array(4, 5, 6), Map("b" -> 2), Row("def")),
Row("binary3".getBytes(), Array(7, 8, 9), Map("c" -> 3), Row("ghi"))
)
)

val schema1 = new StructType(
Array(
StructField("BIN", BinaryType, nullable = false),
StructField("ARR", ArrayType(IntegerType), nullable = false),
StructField("MAP", MapType(StringType, IntegerType), nullable = false),
StructField(
"OBJ",
StructType(Array(StructField("STR", StringType, nullable = false)))
)
)
)

val df = sparkSession.createDataFrame(data, schema1)

df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", tableName4)
.mode(SaveMode.Overwrite)
.save()

val out = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", tableName4)
.schema(schema1)
.load()

val result = out.collect()
assert(result.length == 3)

val bin = result(0).get(0).asInstanceOf[Array[Byte]]
assert(new String(bin).equals("binary1"))
assert(result(0).getList[Int](1).get(0) == 1)
assert(result(1).getList[Int](1).get(1) == 5)
assert(result(2).getList[Int](1).get(2) == 9)
assert(result(1).getMap[String, Int](2)("b") == 2)
assert(result(2).getStruct(3).getString(0) == "ghi")
}
}

}
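Note: for context, a minimal standalone sketch of the scenario this suite now covers, writing a DataFrame that mixes a binary column with a variant-mapped (MAP) column. Connection options are placeholders and the table name is hypothetical; "net.snowflake.spark.snowflake" is the connector's source name.

    import org.apache.spark.sql.{Row, SaveMode, SparkSession}
    import org.apache.spark.sql.types._

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // One binary column plus one variant-mapped column in the same frame.
    val schema = StructType(Array(
      StructField("BIN", BinaryType, nullable = false),
      StructField("MAP", MapType(StringType, IntegerType), nullable = false)))
    val df = spark.createDataFrame(
      java.util.Arrays.asList(Row("bytes".getBytes("UTF-8"), Map("a" -> 1))),
      schema)

    // Placeholder credentials; fill in real connection parameters.
    val sfOptions = Map(
      "sfURL" -> "account.snowflakecomputing.com",
      "sfUser" -> "user", "sfPassword" -> "***",
      "sfDatabase" -> "DB", "sfSchema" -> "PUBLIC", "sfWarehouse" -> "WH")

    df.write
      .format("net.snowflake.spark.snowflake")
      .options(sfOptions)
      .option("dbtable", "MIXED_TYPES_TABLE") // hypothetical table
      .mode(SaveMode.Overwrite)
      .save()

The variant column pushes the write onto the JSON staging path, which previously had no handling for the binary column; the writer changes below add it.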
36 changes: 35 additions & 1 deletion src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -93,10 +93,44 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
.mkString("|")
})
 case SupportedFormat.JSON =>
-  data.toJSON.map(_.toString).rdd
+  // convert binary (Array of Byte) to a Base64-encoded String before COPY
+  val newSchema: StructType = prepareSchemaForJson(data.schema)
+  val conversionsFunction = genConversionFunctionsForJson(data.schema)
+  val newData: RDD[Row] = data.rdd.map(row => {
+    Row.fromSeq(
+      row.toSeq
+        .zip(conversionsFunction)
+        .map {
+          case (element, func) => func(element)
+        }
+    )
+  })
+  spark.createDataFrame(newData, newSchema).toJSON.map(_.toString).rdd
}
}

private def prepareSchemaForJson(schema: StructType): StructType =
StructType.apply(schema.map{
// Binary types will be converted to String type before COPY
case field: StructField if field.dataType == BinaryType =>
StructField(field.name, StringType, field.nullable, field.metadata)
case other => other
})


private def genConversionFunctionsForJson(schema: StructType): Array[Any => Any] =
schema.fields.map(field =>
field.dataType match {
case BinaryType =>
(v: Any) =>
v match {
case null => ""
case bytes: Array[Byte] => Base64.encodeBase64String(bytes)
}
case _ => (input: Any) => input
}
)

/**
 * If column mapping is enabled, remove all useless columns from the input DataFrame
 */
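To make the conversion above concrete, here is a self-contained sketch of what the two new helpers do (names and shapes mirror the diff; Spark and commons-codec on the classpath are assumed): binary fields are retyped to StringType in the schema, and each binary cell is Base64-encoded so the subsequent toJSON emits plain strings.

    import org.apache.commons.codec.binary.Base64
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // Retype BinaryType fields to StringType, as prepareSchemaForJson does.
    def toJsonSchema(schema: StructType): StructType =
      StructType(schema.map {
        case f if f.dataType == BinaryType =>
          StructField(f.name, StringType, f.nullable, f.metadata)
        case other => other
      })

    // One conversion per column: Base64-encode binary cells, pass the rest through.
    def conversions(schema: StructType): Array[Any => Any] =
      schema.fields.map {
        case f if f.dataType == BinaryType =>
          (v: Any) =>
            v match {
              case null => ""
              case bytes: Array[Byte] => Base64.encodeBase64String(bytes)
            }
        case _ => (v: Any) => v
      }

    // Example: a row with a binary cell becomes a row with a Base64 string.
    val schema = StructType(Array(
      StructField("BIN", BinaryType, nullable = true),
      StructField("NUM", IntegerType, nullable = true)))
    val row = Row("abc".getBytes("UTF-8"), 7)
    val converted = Row.fromSeq(
      row.toSeq.zip(conversions(schema)).map { case (cell, f) => f(cell) })
    // converted == Row("YWJj", 7); toJsonSchema(schema) now types BIN as STRING

The Base64 string round-trips: StageWriter (below) asks COPY to decode it back to BINARY on the Snowflake side.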
13 changes: 8 additions & 5 deletions src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -18,15 +18,14 @@ package net.snowflake.spark.snowflake.io
import java.sql.ResultSet
import java.time.LocalDateTime
import java.util.TimeZone

import net.snowflake.client.jdbc.SnowflakeResultSet
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake._
import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BinaryType, StructType}
import org.apache.spark.sql.{SQLContext, SaveMode}
import org.slf4j.LoggerFactory

@@ -836,13 +835,16 @@ private[io] object StageWriter {
val columnPrefix = if (params.useParseJsonForWrite) "parse_json($1):" else "$1:"
if (list.isEmpty || list.get.isEmpty) {
val names = schema.fields
-  .map(x => columnPrefix.concat(
-    if (params.quoteJsonFieldName) {
+  .map(x => {
+    var name = if (params.quoteJsonFieldName) {
      "\"" + x.name + "\""
    } else {
      x.name
    }
-  ))
+    if (x.dataType == BinaryType) {
+      name += "::BINARY"
+    }
+    columnPrefix.concat(name)})
.mkString(",")
ConstantString("from (select") + names + from + "tmp)"
} else {
@@ -908,6 +910,7 @@
ConstantString(s"""
|FILE_FORMAT = (
| TYPE = JSON
| BINARY_FORMAT=BASE64
|)
""".stripMargin) !
}
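Taken together, the StageWriter changes make COPY decode exactly what SnowflakeWriter encoded. For a table with a binary column BIN and a variant column MAP, the generated statement would look roughly like this (illustrative only; table and stage names are hypothetical, and the parse_json($1) prefix appears only when useParseJsonForWrite is set):

    COPY INTO spark_test_table
    FROM (
      SELECT parse_json($1):"BIN"::BINARY, parse_json($1):"MAP"
      FROM @spark_stage tmp
    )
    FILE_FORMAT = (
      TYPE = JSON
      BINARY_FORMAT=BASE64
    )

The ::BINARY cast converts the staged Base64 string back into a BINARY value, and BINARY_FORMAT=BASE64 tells Snowflake to read binary input as Base64 rather than the default hex.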
