SNOW-1790870: Use copy transform in parquet format (#594)
* remove intermediate table

* drop temp table

* use copy transform

* add test for not using staging table
sfc-gh-yuwang authored Nov 7, 2024
1 parent bc76a33 commit 53d404e
Showing 2 changed files with 108 additions and 76 deletions.
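
In short: with this change, a parquet write lands in the target table through a single COPY INTO carrying an explicit column list and a select transform, instead of first loading the files into an intermediate table and copying rows across. A minimal sketch of the statement shape, assembled from the fragments added below (table, column, and stage names are illustrative, not from the diff):

    // Rough shape of the COPY statement the connector now emits for parquet
    val copySketch =
      """COPY INTO "TARGET_TABLE" ("INT_COL","STRING_COL")
        |from (select $1:"INT_COL",$1:"STRING_COL" from @TMP_STAGE tmp)
        |FILE_FORMAT = (
        |  TYPE=PARQUET
        |  USE_VECTORIZED_SCANNER=TRUE
        |)""".stripMargin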
58 changes: 58 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -24,6 +24,7 @@ class ParquetSuite extends IntegrationSuiteBase {
val test_column_map_parquet: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_nested_dataframe: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_no_staging_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

override def afterAll(): Unit = {
jdbcUpdate(s"drop table if exists $test_all_type")
@@ -39,6 +40,7 @@ class ParquetSuite extends IntegrationSuiteBase {
jdbcUpdate(s"drop table if exists $test_column_map_parquet")
jdbcUpdate(s"drop table if exists $test_column_map_not_match")
jdbcUpdate(s"drop table if exists $test_nested_dataframe")
jdbcUpdate(s"drop table if exists $test_no_staging_table")
super.afterAll()
}

@@ -649,4 +651,60 @@ class ParquetSuite extends IntegrationSuiteBase {
assert(result(2).getStruct(3).getString(0) == "ghi")
assert(result(2).getAs[Row]("OBJ").getAs[String]("str") == "ghi")
}

test("test parquet not using staging table") {
val data: RDD[Row] = sc.makeRDD(
List(
Row(
1,
"string value",
123456789L,
123.45,
true,
BigDecimal("12345.6789").bigDecimal,
Timestamp.valueOf("2023-09-16 10:15:30"),
Date.valueOf("2023-01-01")
)
)
)

val schema = StructType(List(
StructField("INT_COL", IntegerType, true),
StructField("STRING_COL", StringType, true),
StructField("LONG_COL", LongType, true),
StructField("DOUBLE_COL", DoubleType, true),
StructField("BOOLEAN_COL", BooleanType, true),
StructField("DECIMAL_COL", DecimalType(20, 10), true),
StructField("TIMESTAMP_COL", TimestampType, true),
StructField("DATE_COL", DateType, true)
))
val df = sparkSession.createDataFrame(data, schema)
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("usestagingtable", "false")
.option("dbtable", test_no_staging_table)
.mode(SaveMode.Overwrite)
.save()


val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_no_staging_table)
.load()

val expectedAnswer = List(
Row(1, "string value", 123456789, 123.45,
true, BigDecimal("12345.6789").bigDecimal.setScale(10),
Timestamp.valueOf("2023-09-16 10:15:30"), Date.valueOf("2023-01-01")
)
)
checkAnswer(newDf, expectedAnswer)

// assert no staging table is left
val res = sparkSession.sql(s"show tables like '%${test_no_staging_table}_STAGING%'").collect()
assert(res.length == 0)
}
}
126 changes: 50 additions & 76 deletions src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -283,7 +283,7 @@ private[io] object StageWriter {
tempStage: String,
format: SupportedFormat,
fileUploadResults: List[FileUploadResult]): Unit = {
if (params.useStagingTable || !params.truncateTable || params.useParquetInWrite()) {
if (params.useStagingTable || !params.truncateTable) {
writeToTableWithStagingTable(sqlContext, conn, schema, saveMode, params,
file, tempStage, format, fileUploadResults)
} else {
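
With params.useParquetInWrite() dropped from this condition, parquet writes are routed like every other format instead of being forced onto the staging-table branch (the assert removed further down, assert(!params.useParquetInWrite() || params.useStagingTable), had made usestagingtable=false illegal for parquet). A write such as the following, mirroring the new test, can now skip the staging table:

    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
      .option("usestagingtable", "false")
      .option("dbtable", "MY_TARGET") // hypothetical table name
      .mode(SaveMode.Overwrite)
      .save()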
@@ -326,7 +326,9 @@
// Create the table if it doesn't exist
if (!tableExists)
{
writeTableState.createTable(tableName, schema, params)
writeTableState.createTable(tableName,
if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
params)
} else if (params.truncateTable && saveMode == SaveMode.Overwrite) {
writeTableState.truncateTable(tableName)
}
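
For parquet the table is now created directly with the Snowflake-mapped schema, so the COPY transform can load into it without an intermediate parquet-typed table. Restated as a standalone sketch (names as in the hunk above):

    // Schema used for CREATE TABLE under this change
    val ddlSchema =
      if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema
    writeTableState.createTable(tableName, ddlSchema, params)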
@@ -390,20 +392,8 @@
getStageTableName(table.name)
}
)

val relayTable = TableName(
if (params.stagingTableNameRemoveQuotesOnly) {
// NOTE: This is the staging table name generation for SC 2.8.1 and earlier.
// It is kept for back-compatibility and it will be removed later without any notice.
s"${table.name.replace('"', '_')}_staging_${Math.abs(Random.nextInt()).toString}"
} else {
getStageTableName(table.name)
}
)
assert(!params.useParquetInWrite() || params.useStagingTable)
val targetTable =
if ((saveMode == SaveMode.Overwrite && params.useStagingTable) ||
params.useParquetInWrite()) {
if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
tempTable
} else {
table
@@ -419,43 +409,26 @@
} else {
DefaultJDBCWrapper.tableExists(params, table.toString)
}

if (params.useParquetInWrite()){
// temporary table to store parquet file
conn.createTable(tempTable.name, schema, params,
overwrite = false, temporary = true)

if (saveMode == SaveMode.Overwrite){
conn.createTable(relayTable.name, params.toSnowflakeSchema(schema), params,
overwrite = false, temporary = false)
} else {
if (!tableExists) {
conn.createTable(table.name, params.toSnowflakeSchema(schema), params,
overwrite = false, temporary = false)
// purge tables when overwriting
if (saveMode == SaveMode.Overwrite && tableExists) {
if (params.useStagingTable) {
if (params.truncateTable) {
conn.createTableLike(tempTable.name, table.name)
}
}

} else {
// purge tables when overwriting
if (saveMode == SaveMode.Overwrite && tableExists) {
if (params.useStagingTable) {
if (params.truncateTable) {
conn.createTableLike(tempTable.name, table.name)
}
} else if (params.truncateTable) conn.truncateTable(table.name)
else conn.dropTable(table.name)
}

// If the SaveMode is 'Append' and the target exists, skip
// CREATE TABLE IF NOT EXIST command. This command doesn't actually
// create a table but it needs CREATE TABLE privilege.
if (saveMode == SaveMode.Overwrite || !tableExists)
{
conn.createTable(targetTable.name, schema, params,
overwrite = false, temporary = false)
}
} else if (params.truncateTable) conn.truncateTable(table.name)
else conn.dropTable(table.name)
}

// If the SaveMode is 'Append' and the target exists, skip
// CREATE TABLE IF NOT EXIST command. This command doesn't actually
// create a table but it needs CREATE TABLE privilege.
if (saveMode == SaveMode.Overwrite || !tableExists)
{
conn.createTable(targetTable.name,
if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
params,
overwrite = false, temporary = false)
}

// pre actions
Utils.executePreActions(
@@ -486,34 +459,18 @@
Option(targetTable)
)

if (params.useParquetInWrite()) {
if (saveMode == SaveMode.Overwrite){
conn.insertIntoTable(relayTable.name, tempTable.name,
params.toSnowflakeSchema(schema), schema, params)
if (tableExists) {
conn.swapTable(table.name, relayTable.name)
conn.dropTable(relayTable.name)
} else {
conn.renameTable(table.name, relayTable.name)
}

if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
if (tableExists) {
conn.swapTable(table.name, tempTable.name)
conn.dropTable(tempTable.name)
} else {
conn.insertIntoTable(table.name, tempTable.name,
params.toSnowflakeSchema(schema), schema, params)
conn.commit()
conn.renameTable(table.name, tempTable.name)
}
conn.dropTable(tempTable.name)
} else {
if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
if (tableExists) {
conn.swapTable(table.name, tempTable.name)
conn.dropTable(tempTable.name)
} else {
conn.renameTable(table.name, tempTable.name)
}
} else {
conn.commit()
}
conn.commit()
}
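
Net effect of this hunk: the parquet-only relay-table sequence (insert into a relay table, then swap or rename, then drop two temporary tables) collapses into the standard finish shared by all formats. For an Overwrite into an existing table with usestagingtable=true, the finish is now simply:

    conn.swapTable(table.name, tempTable.name)
    conn.dropTable(tempTable.name)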

} catch {
case e: Exception =>
// snowflake-todo: try to provide more error information,
@@ -830,7 +787,20 @@
): SnowflakeSQLStatement =
format match {
case SupportedFormat.PARQUET =>
EmptySnowflakeSQLStatement()
ConstantString("(") + params.toSnowflakeSchema(schema)
.map(
field =>
if (params.quoteJsonFieldName) {
if (params.keepOriginalColumnNameCase) {
Utils.quotedNameIgnoreCase(field.name)
} else {
Utils.ensureQuoted(field.name)
}
} else {
field.name
}
)
.mkString(",") + ")"
case SupportedFormat.JSON =>
val tableSchema =
DefaultJDBCWrapper.resolveTable(conn, table.name, params)
@@ -886,7 +856,12 @@
): SnowflakeSQLStatement =
format match {
case SupportedFormat.PARQUET =>
from
ConstantString("from (select") +
schema.map(
field =>
"$1:" + "\"" + field.name + "\""
).mkString(",") +
from + "tmp)"
case SupportedFormat.JSON =>
val columnPrefix = if (params.useParseJsonForWrite) "parse_json($1):" else "$1:"
if (list.isEmpty || list.get.isEmpty) {
@@ -971,7 +946,6 @@
| TYPE=PARQUET
| USE_VECTORIZED_SCANNER=TRUE
| )
| MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
""".stripMargin) !
case SupportedFormat.CSV =>
ConstantString(s"""
