Commit 3405888: fix test

sfc-gh-yuwang committed Oct 1, 2024 · 1 parent 2f7fff8
Showing 3 changed files with 58 additions and 71 deletions.
97 changes: 56 additions & 41 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -11,16 +11,32 @@ import scala.collection.Seq
import scala.util.Random

class ParquetSuite extends IntegrationSuiteBase {
- val test_parquet_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
- val test_parquet_column_map: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
- val test_special_character: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
- val dbtable1: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_all_type: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_all_type_multi_line: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_array_map: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_conversion: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_conversion_by_name: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_column_map: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_trim: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_date: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_special_char: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_special_char_to_exist: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_column_map_parquet: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+ val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

override def afterAll(): Unit = {
runSql(s"drop table if exists $test_parquet_table")
runSql(s"drop table if exists $test_parquet_column_map")
runSql(s"drop table if exists $test_special_character")
runSql(s"drop table if exists $dbtable1")
runSql(s"drop table if exists $test_all_type")
runSql(s"drop table if exists $test_all_type_multi_line")
runSql(s"drop table if exists $test_array_map")
runSql(s"drop table if exists $test_conversion")
runSql(s"drop table if exists $test_conversion_by_name")
runSql(s"drop table if exists $test_column_map")
runSql(s"drop table if exists $test_trim")
runSql(s"drop table if exists $test_date")
runSql(s"drop table if exists $test_special_char")
runSql(s"drop table if exists $test_special_char_to_exist")
runSql(s"drop table if exists $test_column_map_parquet")
runSql(s"drop table if exists $test_column_map_not_match")
super.afterAll()
}

@@ -55,15 +71,15 @@ class ParquetSuite extends IntegrationSuiteBase {
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_parquet_table)
.option("dbtable", test_all_type)
.mode(SaveMode.Overwrite)
.save()


val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_all_type)
.load()

val expectedAnswer = List(
@@ -105,7 +121,7 @@
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_all_type_multi_line)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.mode(SaveMode.Overwrite)
.save()
@@ -114,7 +130,7 @@
val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_all_type_multi_line)
.load()

val expectedAnswer = List(
@@ -157,7 +173,7 @@
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_array_map)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.mode(SaveMode.Overwrite)
.save()
@@ -166,7 +182,7 @@
val res = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_array_map)
.schema(schema)
.load()
.collect()
@@ -192,23 +208,23 @@
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_conversion)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.mode(SaveMode.Overwrite)
.save()

val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_conversion)
.load()

checkAnswer(newDf, List(Row(1, 2, 3)))
}

test("test parquet name conversion with column map by name"){
jdbcUpdate(
s"""create or replace table $test_parquet_column_map
s"""create or replace table $test_conversion_by_name
|(ONE int, TWO int, THREE int, "Fo.ur" int)""".stripMargin
)

@@ -226,7 +242,7 @@
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_conversion_by_name)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("column_mapping", "name")
.mode(SaveMode.Append)
@@ -235,7 +251,7 @@
val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_conversion_by_name)
.load()

checkAnswer(newDf, List(Row(2, 1, 4, 3)))
@@ -245,7 +261,7 @@

test("test parquet name conversion with column map"){
jdbcUpdate(
s"create or replace table $test_parquet_column_map (ONE int, TWO int, THREE int, Four int)"
s"create or replace table $test_column_map (ONE int, TWO int, THREE int, Four int)"
)


@@ -261,7 +277,7 @@
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_column_map)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("columnmap", Map(
"UPPER_CLASS_COL" -> "ONE",
@@ -274,7 +290,7 @@
val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_column_map)
.load()

checkAnswer(newDf, List(Row(1, 2, 3, null)))
@@ -289,7 +305,6 @@
StructField("arr", ArrayType(IntegerType), nullable = false)
)
)
- val tt: String = s"tt_$randomSuffix"
try {
val df = sparkSession
.createDataFrame(
@@ -306,15 +321,15 @@
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", tt)
.option("dbtable", test_trim)
.option(Parameters.PARAM_TRIM_SPACE, "true")
.mode(SaveMode.Overwrite)
.save()

var loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.option("dbtable", test_trim)
.load()

assert(loadDf.select("str").collect().forall(row => row.toSeq.head.toString.length == 4))
@@ -324,22 +339,22 @@
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", tt)
.option("dbtable", test_trim)
.mode(SaveMode.Overwrite)
.save()

loadDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptions)
.option("dbtable", tt)
.option("dbtable", test_trim)
.load()
val result = loadDf.select("str").collect()
assert(result.head.toSeq.head.toString.length == 4)
assert(result(1).toSeq.head.toString.length == 5)
assert(result(2).toSeq.head.toString.length == 6)

} finally {
jdbcUpdate(s"drop table if exists $tt")
jdbcUpdate(s"drop table if exists $test_trim")
}
}

@@ -363,14 +378,14 @@
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_parquet_table)
.option("dbtable", test_date)
.mode(SaveMode.Overwrite)
.save()

val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_parquet_table)
.option("dbtable", test_date)
.load()
newDf.show()

@@ -402,14 +417,14 @@
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_special_character)
.option("dbtable", test_special_char)
.mode(SaveMode.Overwrite)
.save()

val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_special_character)
.option("dbtable", test_special_char)
.load()
newDf.show()

@@ -424,7 +439,7 @@

test("test parquet with special character to existing table"){
jdbcUpdate(
s"""create or replace table $test_special_character
s"""create or replace table $test_special_char_to_exist
|("timestamp1.()col" timestamp, "date1.()col" date)""".stripMargin
)

@@ -447,14 +462,14 @@
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_special_character)
.option("dbtable", test_special_char_to_exist)
.mode(SaveMode.Append)
.save()

val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_special_character)
.option("dbtable", test_special_char_to_exist)
.load()
newDf.show()

Expand All @@ -469,7 +484,7 @@ class ParquetSuite extends IntegrationSuiteBase {

test("Test columnMap with parquet") {
jdbcUpdate(
s"create or replace table $test_parquet_column_map (ONE int, TWO int, THREE int, Four int)"
s"create or replace table $test_column_map_parquet (ONE int, TWO int, THREE int, Four int)"
)


@@ -489,7 +504,7 @@
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_column_map_parquet)
.option("columnmap", Map("UPPER_CLASS_COL" -> "ONE", "lower_class_col" -> "FOUR").toString())
.mode(SaveMode.Overwrite)
.save()
@@ -501,7 +516,7 @@
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_column_map_parquet)
.option("columnmap", Map("aaa" -> "ONE", "Mix_Class_Col" -> "FOUR").toString())
.mode(SaveMode.Append)
.save()
@@ -513,15 +528,15 @@
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_parquet_column_map)
.option("dbtable", test_column_map_parquet)
.option("columnmap", Map("UPPER_CLASS_COL" -> "AAA", "Mix_Class_Col" -> "FOUR").toString())
.mode(SaveMode.Append)
.save()
}
}

test("test error when column map does not match") {
jdbcUpdate(s"create or replace table $dbtable1 (num int, str string)")
jdbcUpdate(s"create or replace table $test_column_map_not_match (num int, str string)")
// auto map
val schema1 = StructType(
List(StructField("str", StringType), StructField("NUM", IntegerType))
@@ -533,7 +548,7 @@
df1.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", dbtable1)
.option("dbtable", test_column_map_not_match)
.mode(SaveMode.Append)
.save()
}
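Taken together, the ParquetSuite changes replace a handful of shared table names with one randomly named table per test, all dropped in afterAll, so tests that previously wrote to the same test_parquet_table can no longer overwrite one another's data. A minimal sketch of the pattern (the helper name is illustrative, not part of the commit, which inlines this expression for each of the twelve test tables):

    // Sketch: one uniquely named Snowflake table per test case.
    // `randomTableName` is a hypothetical helper; the suite above inlines
    // the same expression for every table it declares.
    import scala.util.Random

    def randomTableName(): String =
      Random.alphanumeric.filter(_.isLetter).take(10).mkString

    val test_all_type: String = randomTableName()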
4 changes: 2 additions & 2 deletions src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -297,7 +297,7 @@ object Parameters {
PARAM_TIMESTAMP_LTZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
PARAM_TRIM_SPACE -> "false",
PARAM_USE_PARQUET_IN_WRITE -> "true"
PARAM_USE_PARQUET_IN_WRITE -> "false"

)

@@ -614,7 +614,7 @@ object Parameters {
* Use parquet format in write (disabled by default)
*/
def useParquetInWrite(): Boolean = {
- isTrue(parameters.getOrElse(PARAM_USE_PARQUET_IN_WRITE, "true"))
+ isTrue(parameters.getOrElse(PARAM_USE_PARQUET_IN_WRITE, "false"))

}

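The Parameters.scala change flips the connector-wide default, making parquet-based writes opt-in rather than opt-out. A minimal sketch of a job that still wants parquet writes after this change (df, SNOWFLAKE_SOURCE_NAME, and connectorOptionsNoTable are assumed to be set up as in ParquetSuite above; the table name is a placeholder):

    // Sketch: explicitly opting in to parquet-based writes now that the
    // default for PARAM_USE_PARQUET_IN_WRITE is "false".
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") // default is now "false"
      .option("dbtable", "my_table") // placeholder table name
      .mode(SaveMode.Overwrite)
      .save()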
28 changes: 0 additions & 28 deletions
@@ -400,14 +400,6 @@ private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
s"(${schemaString(schema, params)})")
.execute(bindVariableEnabled)(connection)

- def createTableClone(newTableName: String,
-                      oldTableName: String,
-                      bindVariableEnabled: Boolean = true
-                     ): Unit =
-   (ConstantString("create or replace table") + Identifier(newTableName) +
-     "CLONE" + Identifier(oldTableName))
-     .execute(bindVariableEnabled)(connection)

def insertIntoTable(targetTableName: String,
sourceTableName: String,
targetTableSchema: StructType,
@@ -435,26 +427,6 @@
.execute(bindVariableEnabled)(connection)
}

- def createTableSelectFrom(name: String,
-                           schema: StructType,
-                           stagingTableName: String,
-                           stagingTableSchema: StructType,
-                           params: MergedParameters,
-                           overwrite: Boolean,
-                           temporary: Boolean,
-                           bindVariableEnabled: Boolean = true): Unit = {
-   val columnNames = snowflakeStyleSchema(stagingTableSchema, params).fields
-     .map(_.name)
-     .mkString(",")
-   (ConstantString("create") +
-     (if (overwrite) "or replace" else "") +
-     (if (temporary) "temporary" else "") + "table" + Identifier(name) +
-     s"(${schemaString(schema, params)})" +
-     "as select" + s"$columnNames" +
-     "from" + Identifier(stagingTableName)
-   ).execute(bindVariableEnabled)(connection)
- }

def truncateTable(table: String,
bindVariableEnabled: Boolean = true): Unit =
(ConstantString("truncate") + table)
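Note: no call sites are modified in this commit, which suggests createTableClone and createTableSelectFrom were already unreferenced when they were deleted. Judging only from the ConstantString fragments above, createTableClone issued roughly `create or replace table <new> CLONE <old>`, and createTableSelectFrom issued roughly `create [or replace] [temporary] table <name> (<columns>) as select <staging columns> from <staging table>`; the surviving insertIntoTable and truncateTable helpers are untouched.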
