From a716144d3d556771efdc2bf791b5d3b538a80cae Mon Sep 17 00:00:00 2001
From: Bing Li <63471091+sfc-gh-bli@users.noreply.github.com>
Date: Thu, 15 Feb 2024 16:19:45 -0800
Subject: [PATCH] SNOW-1045973 Introduce trim_space Parameter (#545)

* add new trim_space parameter

* fix

* add test

* fix test
---
 .../spark/snowflake/IssueSuite.scala          | 116 ++++++++++++++++++
 .../spark/snowflake/Parameters.scala          |  11 +-
 .../spark/snowflake/SnowflakeWriter.scala     |  24 ++--
 .../spark/snowflake/io/StageWriter.scala      |   1 -
 4 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
index 1be7e111..97edd56c 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -20,6 +20,122 @@ class IssueSuite extends IntegrationSuiteBase {
     super.beforeEach()
   }
 
+  test("trim space - csv") {
+    val st1 = new StructType(
+      Array(StructField("str", StringType, nullable = false))
+    )
+    val tt: String = s"tt_$randomSuffix"
+    try {
+      val df = sparkSession
+        .createDataFrame(
+          sparkSession.sparkContext.parallelize(
+            Seq(
+              Row("ab c"),
+              Row(" a bc"),
+              Row("abdc  ")
+            )
+          ),
+          st1
+        )
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .option(Parameters.PARAM_TRIM_SPACE, "true")
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      var loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+
+      assert(loadDf.collect().forall(row => row.toSeq.head.toString.length == 4))
+
+      // trim_space is disabled by default
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+      val result = loadDf.collect()
+      assert(result.head.toSeq.head.toString.length == 4)
+      assert(result(1).toSeq.head.toString.length == 5)
+      assert(result(2).toSeq.head.toString.length == 6)
+
+
+    } finally {
+      jdbcUpdate(s"drop table if exists $tt")
+    }
+  }
+
+  test("trim space - json") {
+    val st1 = new StructType(
+      Array(
+        StructField("str", StringType, nullable = false),
+        StructField("arr", ArrayType(IntegerType), nullable = false)
+      )
+    )
+    val tt: String = s"tt_$randomSuffix"
+    try {
+      val df = sparkSession
+        .createDataFrame(
+          sparkSession.sparkContext.parallelize(
+            Seq(
+              Row("ab c", Array(1, 2, 3)),
+              Row(" a bc", Array(2, 2, 3)),
+              Row("abdc  ", Array(3, 2, 3))
+            )
+          ),
+          st1
+        )
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .option(Parameters.PARAM_TRIM_SPACE, "true")
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      var loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+
+      assert(loadDf.select("str").collect().forall(row => row.toSeq.head.toString.length == 4))
+
+      // trim_space is disabled by default
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+      val result = loadDf.select("str").collect()
+      assert(result.head.toSeq.head.toString.length == 4)
+      assert(result(1).toSeq.head.toString.length == 5)
+      assert(result(2).toSeq.head.toString.length == 6)
+
+    } finally {
+      jdbcUpdate(s"drop table if exists $tt")
+    }
+  }
+
   test("csv delimiter character should not break rows") {
     val st1 = new StructType(
       Array(
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index fcf2e19e..3d7f1fd9 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -85,6 +85,7 @@ object Parameters {
   val PARAM_COLUMN_MAP: String = knownParam("columnmap")
   val PARAM_TRUNCATE_COLUMNS: String = knownParam("truncate_columns")
   val PARAM_PURGE: String = knownParam("purge")
+  val PARAM_TRIM_SPACE: String = knownParam("trim_space")
 
   val PARAM_TRUNCATE_TABLE: String = knownParam("truncate_table")
   val PARAM_CONTINUE_ON_ERROR: String = knownParam("continue_on_error")
@@ -288,7 +289,8 @@ object Parameters {
     PARAM_USE_AWS_MULTIPLE_PARTS_UPLOAD -> "true",
     PARAM_TIMESTAMP_NTZ_OUTPUT_FORMAT -> "YYYY-MM-DD HH24:MI:SS.FF3",
     PARAM_TIMESTAMP_LTZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
-    PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3"
+    PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
+    PARAM_TRIM_SPACE -> "false"
   )
 
   /**
@@ -837,6 +839,13 @@ object Parameters {
   def useStagingTable: Boolean =
     isTrue(parameters(PARAM_USE_STAGING_TABLE))
 
+  /**
+   * Boolean that specifies whether to trim leading and trailing white space
+   * from String fields when writing. Defaults to false.
+   */
+  def trimSpace: Boolean =
+    isTrue(parameters(PARAM_TRIM_SPACE))
+
   /**
    * Extra options to append to the Snowflake COPY command (e.g. "MAXERROR 100").
    */
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index e35c8bf2..52c47472 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -18,9 +18,8 @@
 package net.snowflake.spark.snowflake
 
 import java.sql.{Date, Timestamp}
-
 import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
-import net.snowflake.spark.snowflake.Parameters.MergedParameters
+import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters}
 import net.snowflake.spark.snowflake.io.SupportedFormat
 import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
 import org.apache.spark.rdd.RDD
@@ -83,7 +82,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
 
     format match {
       case SupportedFormat.CSV =>
-        val conversionFunction = genConversionFunctions(data.schema)
+        val conversionFunction = genConversionFunctions(data.schema, params)
         data.rdd.map(row => {
           row.toSeq
             .zip(conversionFunction)
@@ -95,7 +94,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       case SupportedFormat.JSON =>
         // convert binary (Array of Byte) to encoded base64 String before COPY
         val newSchema: StructType = prepareSchemaForJson(data.schema)
-        val conversionsFunction = genConversionFunctionsForJson(data.schema)
+        val conversionsFunction = genConversionFunctionsForJson(data.schema, params)
         val newData: RDD[Row] = data.rdd.map(row => {
           Row.fromSeq(
             row.toSeq
@@ -118,9 +117,15 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       })
 
-  private def genConversionFunctionsForJson(schema: StructType): Array[Any => Any] =
+  private def genConversionFunctionsForJson(schema: StructType,
+                                            params: MergedParameters): Array[Any => Any] =
     schema.fields.map(field =>
       field.dataType match {
+        case StringType =>
+          (v: Any) =>
+            if (v != null && params.trimSpace) {
+              v.toString.trim
+            } else v
         case BinaryType =>
           (v: Any) =>
             v match {
@@ -157,7 +162,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
   }
 
   // Prepare a set of conversion functions, based on the schema
-  def genConversionFunctions(schema: StructType): Array[Any => Any] =
+  def genConversionFunctions(schema: StructType, params: MergedParameters): Array[Any => Any] =
     schema.fields.map { field =>
       field.dataType match {
         case DateType =>
@@ -177,7 +182,12 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
           (v: Any) => {
             if (v == null) ""
-            else Conversions.formatString(v.asInstanceOf[String])
+            else {
+              val trimmed = if (params.trimSpace) {
+                v.toString.trim
+              } else v
+              Conversions.formatString(trimmed.asInstanceOf[String])
+            }
           }
         case BinaryType =>
           (v: Any) =>
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index fbd19b22..aa5a4795 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -908,7 +908,6 @@ private[io] object StageWriter {
         params.getStringTimestampFormat.get
       }
 
-
     val formatString = format match {
       case SupportedFormat.CSV =>
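
Usage sketch for the new parameter (assumes an existing SparkSession, a
DataFrame `df`, and a `connectorOptions` map with valid Snowflake connection
settings; the table name below is a placeholder):

    import org.apache.spark.sql.SaveMode

    // Enable trim_space so leading/trailing spaces in String fields are
    // removed before COPY INTO; e.g. " a bc" is stored as "a bc".
    df.write
      .format("net.snowflake.spark.snowflake")
      .options(connectorOptions)       // account, user, password, ...
      .option("dbtable", "MY_TABLE")   // placeholder table name
      .option("trim_space", "true")    // new parameter; defaults to "false"
      .mode(SaveMode.Overwrite)
      .save()

Because the trimming is applied in the writer's per-field conversion
functions before the COPY command runs, it behaves the same for both the CSV
and JSON staging paths changed above.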