SNOW-1045973 Introduce trim_space Parameter (#545)
* add new trim_space parameter

* fix

* add test

* fix test
sfc-gh-bli authored Feb 16, 2024
1 parent 5fc3a1e commit a716144
Showing 4 changed files with 143 additions and 9 deletions.
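
For orientation, the new option is passed like any other connector write option. A minimal usage sketch, assuming an already-populated sfOptions connection map and a placeholder table name (neither is taken from this commit):

df.write
  .format("net.snowflake.spark.snowflake") // the connector's SNOWFLAKE_SOURCE_NAME
  .options(sfOptions)                      // sfURL, sfUser, sfPassword, sfDatabase, sfSchema, ...
  .option("dbtable", "TRIM_SPACE_DEMO")    // hypothetical target table
  .option("trim_space", "true")            // parameter introduced by this commit; defaults to "false"
  .mode(SaveMode.Overwrite)
  .save()

When enabled, leading and trailing spaces are stripped from String columns on the write path for both the CSV and JSON staging formats, as exercised by the tests below.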
116 changes: 116 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -20,6 +20,122 @@ class IssueSuite extends IntegrationSuiteBase {
super.beforeEach()
}

test("trim space - csv") {
val st1 = new StructType(
Array(StructField("str", StringType, nullable = false))
)
val tt: String = s"tt_$randomSuffix"
try {
val df = sparkSession
.createDataFrame(
sparkSession.sparkContext.parallelize(
Seq(
Row("ab c"),
Row(" a bc"),
Row("abdc ")
)
),
st1
)
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.option(Parameters.PARAM_TRIM_SPACE, "true")
.mode(SaveMode.Overwrite)
.save()

var loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.load()

assert(loadDf.collect().forall(row => row.toSeq.head.toString.length == 4))

// disabled by default
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.mode(SaveMode.Overwrite)
.save()

loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.load()
val result = loadDf.collect()
assert(result.head.toSeq.head.toString.length == 4)
assert(result(1).toSeq.head.toString.length == 5)
assert(result(2).toSeq.head.toString.length == 6)


} finally {
jdbcUpdate(s"drop table if exists $tt")
}
}

test("trim space - json") {
val st1 = new StructType(
Array(
StructField("str", StringType, nullable = false),
StructField("arr", ArrayType(IntegerType), nullable = false)
)
)
val tt: String = s"tt_$randomSuffix"
try {
val df = sparkSession
.createDataFrame(
sparkSession.sparkContext.parallelize(
Seq(
Row("ab c", Array(1, 2, 3)),
Row(" a bc", Array(2, 2, 3)),
Row("abdc ", Array(3, 2, 3))
)
),
st1
)
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.option(Parameters.PARAM_TRIM_SPACE, "true")
.mode(SaveMode.Overwrite)
.save()

var loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.load()

assert(loadDf.select("str").collect().forall(row => row.toSeq.head.toString.length == 4))

// disabled by default
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.mode(SaveMode.Overwrite)
.save()

loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.load()
val result = loadDf.select("str").collect()
assert(result.head.toSeq.head.toString.length == 4)
assert(result(1).toSeq.head.toString.length == 5)
assert(result(2).toSeq.head.toString.length == 6)

} finally {
jdbcUpdate(s"drop table if exists $tt")
}
}

test("csv delimiter character should not break rows") {
val st1 = new StructType(
Array(
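
The length assertions in the two tests above follow directly from String.trim semantics, which remove only leading and trailing whitespace. Illustrative values mirroring the test data:

"ab c".trim   // "ab c" -> length stays 4
" a bc".trim  // "a bc" -> length 5 becomes 4
"abdc  ".trim // "abdc" -> length 6 becomes 4

So with trim_space enabled every loaded value has length 4, while with the default ("false") the original lengths 4, 5 and 6 survive the round trip.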
11 changes: 10 additions & 1 deletion src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -85,6 +85,7 @@ object Parameters {
val PARAM_COLUMN_MAP: String = knownParam("columnmap")
val PARAM_TRUNCATE_COLUMNS: String = knownParam("truncate_columns")
val PARAM_PURGE: String = knownParam("purge")
val PARAM_TRIM_SPACE: String = knownParam("trim_space")

val PARAM_TRUNCATE_TABLE: String = knownParam("truncate_table")
val PARAM_CONTINUE_ON_ERROR: String = knownParam("continue_on_error")
@@ -288,7 +289,8 @@ object Parameters {
PARAM_USE_AWS_MULTIPLE_PARTS_UPLOAD -> "true",
PARAM_TIMESTAMP_NTZ_OUTPUT_FORMAT -> "YYYY-MM-DD HH24:MI:SS.FF3",
PARAM_TIMESTAMP_LTZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3"
PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
PARAM_TRIM_SPACE -> "false"
)

/**
@@ -837,6 +839,13 @@ object Parameters {
def useStagingTable: Boolean =
isTrue(parameters(PARAM_USE_STAGING_TABLE))

/**
* Boolean that specifies whether to trim leading and trailing white space from String fields.
* Defaults to false.
*/
def trimSpace: Boolean =
isTrue(parameters(PARAM_TRIM_SPACE))

/**
* Extra options to append to the Snowflake COPY command (e.g. "MAXERROR 100").
*/
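
For context, the trimSpace getter simply looks up the merged option map, into which the new "false" default above has been folded. A rough sketch of how calling code consults it (userOptions is a hypothetical Map[String, String] assumed to already hold the mandatory connection settings; mergeParameters is the existing helper that SnowflakeWriter imports below):

val params = Parameters.mergeParameters(userOptions)
params.trimSpace   // false - the defaults map supplies "false" when the user sets nothing

val trimming = Parameters.mergeParameters(userOptions + ("trim_space" -> "true"))
trimming.trimSpace // true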
24 changes: 17 additions & 7 deletions src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -18,9 +18,8 @@
package net.snowflake.spark.snowflake

import java.sql.{Date, Timestamp}

import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters}
import net.snowflake.spark.snowflake.io.SupportedFormat
import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import org.apache.spark.rdd.RDD
@@ -83,7 +82,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {

format match {
case SupportedFormat.CSV =>
val conversionFunction = genConversionFunctions(data.schema)
val conversionFunction = genConversionFunctions(data.schema, params)
data.rdd.map(row => {
row.toSeq
.zip(conversionFunction)
@@ -95,7 +94,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
case SupportedFormat.JSON =>
// convert binary (Array of Byte) to encoded base64 String before COPY
val newSchema: StructType = prepareSchemaForJson(data.schema)
val conversionsFunction = genConversionFunctionsForJson(data.schema)
val conversionsFunction = genConversionFunctionsForJson(data.schema, params)
val newData: RDD[Row] = data.rdd.map(row => {
Row.fromSeq(
row.toSeq
@@ -118,9 +117,15 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
})


private def genConversionFunctionsForJson(schema: StructType): Array[Any => Any] =
private def genConversionFunctionsForJson(schema: StructType,
params: MergedParameters): Array[Any => Any] =
schema.fields.map(field =>
field.dataType match {
case StringType =>
(v: Any) =>
if (params.trimSpace) {
v.toString.trim
} else v
case BinaryType =>
(v: Any) =>
v match {
Expand Down Expand Up @@ -157,7 +162,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
}

// Prepare a set of conversion functions, based on the schema
def genConversionFunctions(schema: StructType): Array[Any => Any] =
def genConversionFunctions(schema: StructType, params: MergedParameters): Array[Any => Any] =
schema.fields.map { field =>
field.dataType match {
case DateType =>
Expand All @@ -177,7 +182,12 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
(v: Any) =>
{
if (v == null) ""
else Conversions.formatString(v.asInstanceOf[String])
else {
val trimmed = if (params.trimSpace) {
v.toString.trim
} else v
Conversions.formatString(trimmed.asInstanceOf[String])
}
}
case BinaryType =>
(v: Any) =>
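
The writer change follows the file's existing pattern: one conversion closure per schema field, zipped against each row's values before the data is staged. A stripped-down, self-contained sketch of that idea (not the connector's actual code, which also handles dates, timestamps and binary columns):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructType}

// Build one Any => Any function per field; trim Strings only when requested.
def conversions(schema: StructType, trimSpace: Boolean): Array[Any => Any] =
  schema.fields.map { field =>
    field.dataType match {
      case StringType if trimSpace =>
        (v: Any) => if (v == null) v else v.toString.trim
      case _ =>
        (v: Any) => v // all other types pass through unchanged
    }
  }

// Usage: apply the functions positionally to a row's values.
def convertRow(row: Row, fns: Array[Any => Any]): Seq[Any] =
  row.toSeq.zip(fns).map { case (v, f) => f(v) }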
1 change: 0 additions & 1 deletion src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -908,7 +908,6 @@ private[io] object StageWriter {
params.getStringTimestampFormat.get
}


val formatString =
format match {
case SupportedFormat.CSV =>
