Fix Parquet Issues #591

Merged: 8 commits, Nov 4, 2024
src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala (43 changes: 39 additions & 4 deletions)
@@ -502,7 +502,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because only support SaveMode.Append
assertThrows[UnsupportedOperationException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -514,7 +514,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because "aaa" is not a column name of DF
assertThrows[IllegalArgumentException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -526,7 +526,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because "AAA" is not a column name of table in snowflake database
assertThrows[IllegalArgumentException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -536,6 +536,41 @@
}
}

test("null value in array") {
val data: RDD[Row] = sc.makeRDD(
List(
Row(
Array(null, "one", "two", "three"),
),
Row(
Array("one", null, "two", "three"),
)
)
)

val schema = StructType(List(
StructField("ARRAY_STRING_FIELD",
ArrayType(StringType, containsNull = true), nullable = true)))
val df = sparkSession.createDataFrame(data, schema)
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_array_map)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.mode(SaveMode.Overwrite)
.save()


val res = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_array_map)
.schema(schema)
.load().collect()
assert(res.head.getSeq(0) == Seq("null", "one", "two", "three"))
assert(res(1).getSeq(0) == Seq("one", "null", "two", "three"))
}

test("test error when column map does not match") {
jdbcUpdate(s"create or replace table $test_column_map_not_match (num int, str string)")
// auto map
@@ -547,7 +582,7 @@

assertThrows[SQLException]{
df1.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_not_match)
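A note on the new "null value in array" test above: with the parquet write path enabled, null elements inside a string array come back from Snowflake as the literal string "null" (that is what the two assertions pin down). If downstream code needs real SQL NULL elements, it has to normalize them after the read. A minimal sketch under that assumption, where dfRead stands for the result of the sparkSession.read(...).load() call in the test, before .collect():

import org.apache.spark.sql.functions.expr

// Sketch only (not part of the PR): map the literal "null" strings produced by
// the round trip back to SQL NULL elements. Column name matches the test schema.
val cleaned = dfRead.withColumn(
  "ARRAY_STRING_FIELD",
  expr("transform(ARRAY_STRING_FIELD, x -> nullif(x, 'null'))")
)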
src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala (15 changes: 3 additions & 12 deletions)
@@ -311,16 +311,13 @@ class StageSuite extends IntegrationSuiteBase {
try {
// The credential for the external stage is fake.
val azureExternalStage = ExternalAzureStorage(
param,
containerName = "test_fake_container",
azureAccount = "test_fake_account",
azureEndpoint = "blob.core.windows.net",
azureSAS =
"?sig=test_test_test_test_test_test_test_test_test_test_test_test" +
"_test_test_test_test_test_fak&spr=https&sp=rwdl&sr=c",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection
@@ -367,13 +364,10 @@
try {
// The credential for the external stage is fake.
val s3ExternalStage = ExternalS3Storage(
param,
bucketName = "test_fake_bucket",
awsId = "TEST_TEST_TEST_TEST1",
awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection,
@@ -487,13 +481,10 @@
try {
// The credential for the external stage is fake.
val s3ExternalStage = ExternalS3Storage(
param,
bucketName = "test_fake_bucket",
awsId = "TEST_TEST_TEST_TEST1",
awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection,
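The StageSuite hunks above track a constructor refactor: ExternalAzureStorage and ExternalS3Storage now receive the merged parameter object (param) up front instead of having several of its fields (proxy info, retry count, URL, backoff, partition count) threaded through individually. A generic sketch of that pattern, with illustrative types rather than the connector's exact signatures:

// Illustrative only: passing the settings bundle keeps call sites stable
// when a new setting is added, instead of touching every constructor call.
final case class Settings(maxRetryCount: Int, sfURL: String, useExponentialBackoff: Boolean)

class StorageBefore(maxRetryCount: Int, sfURL: String, useExponentialBackoff: Boolean)
class StorageAfter(settings: Settings)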
@@ -261,7 +261,7 @@ object Parameters {
Set("off", "no", "false", "0", "disabled")

// enable parquet format
val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write ")
val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write")

/**
* Helper method to check if a given string represents some form
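The one-character change above drops a trailing space from the registered parameter key. With the space, PARAM_USE_PARQUET_IN_WRITE was registered as "use_parquet_in_write " while a caller setting the documented name "use_parquet_in_write" would presumably never match it. A small illustration (not the connector's lookup code):

// Illustration only: a trailing space in the registered key silently
// prevents it from matching the option name callers actually pass.
val registeredBefore = "use_parquet_in_write " // key as registered before this fix
val supplied         = "use_parquet_in_write"  // key callers pass
assert(registeredBefore != supplied)
assert(registeredBefore.trim == supplied)

Callers that reference Parameters.PARAM_USE_PARQUET_IN_WRITE directly, as the tests above do, sidestep this class of mismatch.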
@@ -17,27 +17,15 @@

package net.snowflake.spark.snowflake

import scala.collection.JavaConverters._
import java.sql.{Date, Timestamp}
import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString}
import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters}
import net.snowflake.spark.snowflake.SparkConnectorContext.getClass
import net.snowflake.spark.snowflake.Utils.ensureUnquoted
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.io.SupportedFormat
import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.slf4j.LoggerFactory

import java.nio.ByteBuffer
import java.time.{LocalDate, ZoneId, ZoneOffset}
import java.util.concurrent.TimeUnit
import scala.collection.mutable

/**
* Functions to write data to Snowflake.
@@ -198,42 +186,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
format match {
case SupportedFormat.PARQUET =>
val snowflakeStyleSchema = mapColumn(data.schema, params, snowflakeStyle = true)
val schema = io.ParquetUtils.convertStructToAvro(snowflakeStyleSchema)
(data.rdd.map (row => {
def rowToAvroRecord(row: Row,
schema: Schema,
snowflakeStyleSchema: StructType,
params: MergedParameters): GenericData.Record = {
val record = new GenericData.Record(schema)
row.toSeq.zip(snowflakeStyleSchema.names).foreach {
case (row: Row, name) =>
record.put(name,
rowToAvroRecord(
row,
schema.getField(name).schema().getTypes.get(0),
snowflakeStyleSchema(name).dataType.asInstanceOf[StructType],
params
))
case (map: scala.collection.immutable.Map[Any, Any], name) =>
record.put(name, map.asJava)
case (str: String, name) =>
record.put(name, if (params.trimSpace) str.trim else str)
case (arr: mutable.WrappedArray[Any], name) =>
record.put(name, arr.toArray)
case (decimal: java.math.BigDecimal, name) =>
record.put(name, ByteBuffer.wrap(decimal.unscaledValue().toByteArray))
case (timestamp: java.sql.Timestamp, name) =>
record.put(name, timestamp.toString)
case (date: java.sql.Date, name) =>
record.put(name, date.toString)
case (date: java.time.LocalDateTime, name) =>
record.put(name, date.toEpochSecond(ZoneOffset.UTC))
case (value, name) => record.put(name, value)
}
record
}
rowToAvroRecord(row, schema, snowflakeStyleSchema, params)
}), snowflakeStyleSchema)
(data.rdd.asInstanceOf[RDD[Any]], snowflakeStyleSchema)
case SupportedFormat.CSV =>
val conversionFunction = genConversionFunctions(data.schema, params)
(data.rdd.map(row => {
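The final hunk removes the in-writer Row-to-Avro conversion (which had stringified timestamps and dates, unwrapped arrays, and re-encoded decimals as byte buffers) and instead hands the untouched RDD plus the Snowflake-style schema to the parquet writing path, where the per-row conversion presumably happens now. The PARQUET branch's result is just a pair; types spelled out for clarity (a sketch, not a change to the code above):

// Shape of the value returned by the PARQUET branch after this change.
val prepared: (RDD[Any], StructType) =
  (data.rdd.asInstanceOf[RDD[Any]], snowflakeStyleSchema)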