From 53d404e1e074662ca4d45df5f537dbf70cbd35cb Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Thu, 7 Nov 2024 11:03:23 -0800
Subject: [PATCH] SNOW-1790870: Use copy transform in parquet format (#594)

* remove intermediate table

* drop temp table

* use copy transform

* add test for not using staging table
---
 .../spark/snowflake/ParquetSuite.scala        |  58 ++++++++
 .../spark/snowflake/io/StageWriter.scala      | 126 +++++++-----
 2 files changed, 108 insertions(+), 76 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
index 811b845f..f9725b32 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -24,6 +24,7 @@ class ParquetSuite extends IntegrationSuiteBase {
   val test_column_map_parquet: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
   val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
   val test_nested_dataframe: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+  val test_no_staging_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

   override def afterAll(): Unit = {
     jdbcUpdate(s"drop table if exists $test_all_type")
@@ -39,6 +40,7 @@ class ParquetSuite extends IntegrationSuiteBase {
     jdbcUpdate(s"drop table if exists $test_column_map_parquet")
     jdbcUpdate(s"drop table if exists $test_column_map_not_match")
     jdbcUpdate(s"drop table if exists $test_nested_dataframe")
+    jdbcUpdate(s"drop table if exists $test_no_staging_table")
     super.afterAll()
   }

@@ -649,4 +651,60 @@ class ParquetSuite extends IntegrationSuiteBase {
     assert(result(2).getStruct(3).getString(0) == "ghi")
     assert(result(2).getAs[Row]("OBJ").getAs[String]("str") == "ghi")
   }
+
+  test("test parquet not using staging table") {
+    val data: RDD[Row] = sc.makeRDD(
+      List(
+        Row(
+          1,
+          "string value",
+          123456789L,
+          123.45,
+          true,
+          BigDecimal("12345.6789").bigDecimal,
+          Timestamp.valueOf("2023-09-16 10:15:30"),
+          Date.valueOf("2023-01-01")
+        )
+      )
+    )
+
+    val schema = StructType(List(
+      StructField("INT_COL", IntegerType, true),
+      StructField("STRING_COL", StringType, true),
+      StructField("LONG_COL", LongType, true),
+      StructField("DOUBLE_COL", DoubleType, true),
+      StructField("BOOLEAN_COL", BooleanType, true),
+      StructField("DECIMAL_COL", DecimalType(20, 10), true),
+      StructField("TIMESTAMP_COL", TimestampType, true),
+      StructField("DATE_COL", DateType, true)
+    ))
+    val df = sparkSession.createDataFrame(data, schema)
+    df.write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
+      .option("usestagingtable", "false")
+      .option("dbtable", test_no_staging_table)
+      .mode(SaveMode.Overwrite)
+      .save()
+
+    val newDf = sparkSession.read
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_no_staging_table)
+      .load()
+
+    val expectedAnswer = List(
+      Row(1, "string value", 123456789, 123.45,
+        true, BigDecimal("12345.6789").bigDecimal.setScale(10),
+        Timestamp.valueOf("2023-09-16 10:15:30"), Date.valueOf("2023-01-01")
+      )
+    )
+    checkAnswer(newDf, expectedAnswer)
+
+    // assert that no staging table is left behind
+    val res = sparkSession.sql(s"show tables like '%${test_no_staging_table}_STAGING%'").collect()
+    assert(res.length == 0)
+  }
 }
\ No newline at end of file
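Note on the test above: DECIMAL_COL is declared as DecimalType(20, 10), so the
value read back from Snowflake carries scale 10, while BigDecimal("12345.6789")
carries scale 4. java.math.BigDecimal.equals is scale-sensitive, which is why
the expected row aligns scales with setScale(10). A minimal, self-contained
sketch of that behavior (plain Scala, no connector code involved):

    import java.math.BigDecimal

    val written  = new BigDecimal("12345.6789")              // scale 4
    val readBack = new BigDecimal("12345.6789").setScale(10) // scale 10, as DECIMAL(20, 10)

    assert(written.compareTo(readBack) == 0) // numerically equal
    assert(!written.equals(readBack))        // equals also compares scale
    assert(written.setScale(10) == readBack) // align scales before comparing rows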
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index fe0e2e56..2c104bd1 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -283,7 +283,7 @@ private[io] object StageWriter {
       tempStage: String,
       format: SupportedFormat,
       fileUploadResults: List[FileUploadResult]): Unit = {
-    if (params.useStagingTable || !params.truncateTable || params.useParquetInWrite()) {
+    if (params.useStagingTable || !params.truncateTable) {
       writeToTableWithStagingTable(sqlContext, conn, schema, saveMode, params,
         file, tempStage, format, fileUploadResults)
     } else {
@@ -326,7 +326,9 @@ private[io] object StageWriter {

     // If create table if table doesn't exist
     if (!tableExists) {
-      writeTableState.createTable(tableName, schema, params)
+      writeTableState.createTable(tableName,
+        if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
+        params)
     } else if (params.truncateTable && saveMode == SaveMode.Overwrite) {
       writeTableState.truncateTable(tableName)
     }
@@ -390,20 +392,8 @@ private[io] object StageWriter {
         getStageTableName(table.name)
       }
     )
-
-    val relayTable = TableName(
-      if (params.stagingTableNameRemoveQuotesOnly) {
-        // NOTE: This is the staging table name generation for SC 2.8.1 and earlier.
-        // It is kept for back-compatibility and it will be removed later without any notice.
-        s"${table.name.replace('"', '_')}_staging_${Math.abs(Random.nextInt()).toString}"
-      } else {
-        getStageTableName(table.name)
-      }
-    )
-    assert(!params.useParquetInWrite() || params.useStagingTable)
     val targetTable =
-      if ((saveMode == SaveMode.Overwrite && params.useStagingTable) ||
-        params.useParquetInWrite()) {
+      if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
         tempTable
       } else {
         table
@@ -419,43 +409,26 @@ private[io] object StageWriter {
       } else {
        DefaultJDBCWrapper.tableExists(params, table.toString)
       }
-
-      if (params.useParquetInWrite()){
-        // temporary table to store parquet file
-        conn.createTable(tempTable.name, schema, params,
-          overwrite = false, temporary = true)
-
-        if (saveMode == SaveMode.Overwrite){
-          conn.createTable(relayTable.name, params.toSnowflakeSchema(schema), params,
-            overwrite = false, temporary = false)
-        } else {
-          if (!tableExists) {
-            conn.createTable(table.name, params.toSnowflakeSchema(schema), params,
-              overwrite = false, temporary = false)
-          }
-        }
-
-      } else {
-        // purge tables when overwriting
-        if (saveMode == SaveMode.Overwrite && tableExists) {
-          if (params.useStagingTable) {
-            if (params.truncateTable) {
-              conn.createTableLike(tempTable.name, table.name)
-            }
-          } else if (params.truncateTable) conn.truncateTable(table.name)
-          else conn.dropTable(table.name)
-        }
-
-        // If the SaveMode is 'Append' and the target exists, skip
-        // CREATE TABLE IF NOT EXIST command. This command doesn't actually
-        // create a table but it needs CREATE TABLE privilege.
-        if (saveMode == SaveMode.Overwrite || !tableExists)
-        {
-          conn.createTable(targetTable.name, schema, params,
-            overwrite = false, temporary = false)
-        }
+      // purge tables when overwriting
+      if (saveMode == SaveMode.Overwrite && tableExists) {
+        if (params.useStagingTable) {
+          if (params.truncateTable) {
+            conn.createTableLike(tempTable.name, table.name)
+          }
+        } else if (params.truncateTable) conn.truncateTable(table.name)
+        else conn.dropTable(table.name)
       }
+
+      // If the SaveMode is 'Append' and the target exists, skip the
+      // CREATE TABLE IF NOT EXISTS command. This command doesn't actually
+      // create a table but it needs the CREATE TABLE privilege.
+      if (saveMode == SaveMode.Overwrite || !tableExists) {
+        conn.createTable(targetTable.name,
+          if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
+          params,
+          overwrite = false, temporary = false)
+      }

       // pre actions
       Utils.executePreActions(
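With the Parquet-only relay-table branch gone, every format now follows the
same overwrite flow in the next hunk below: load into tempTable, then swap it
with the target (or rename it if the target does not exist yet). For
orientation, a hedged sketch of the standard Snowflake DDL such a flow
corresponds to; the real statements are issued by the connector's JDBC
helpers, and the function names and argument order here are illustrative,
not the connector's actual internals:

    // Illustrative stand-ins for conn.swapTable / conn.renameTable.
    def swapSql(target: String, staging: String): String =
      s"ALTER TABLE $staging SWAP WITH $target" // atomic exchange of the two tables

    def renameSql(target: String, staging: String): String =
      s"ALTER TABLE $staging RENAME TO $target" // first write: promote staging directly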
@@ -486,34 +459,18 @@ private[io] object StageWriter {
         Option(targetTable)
       )

-      if (params.useParquetInWrite()) {
-        if (saveMode == SaveMode.Overwrite){
-          conn.insertIntoTable(relayTable.name, tempTable.name,
-            params.toSnowflakeSchema(schema), schema, params)
-          if (tableExists) {
-            conn.swapTable(table.name, relayTable.name)
-            conn.dropTable(relayTable.name)
-          } else {
-            conn.renameTable(table.name, relayTable.name)
-          }
-        } else {
-          conn.insertIntoTable(table.name, tempTable.name,
-            params.toSnowflakeSchema(schema), schema, params)
-          conn.commit()
-        }
-        conn.dropTable(tempTable.name)
-      } else {
-        if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
-          if (tableExists) {
-            conn.swapTable(table.name, tempTable.name)
-            conn.dropTable(tempTable.name)
-          } else {
-            conn.renameTable(table.name, tempTable.name)
-          }
-        } else {
-          conn.commit()
-        }
+
+      if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
+        if (tableExists) {
+          conn.swapTable(table.name, tempTable.name)
+          conn.dropTable(tempTable.name)
+        } else {
+          conn.renameTable(table.name, tempTable.name)
+        }
+      } else {
+        conn.commit()
       }
+
     } catch {
       case e: Exception =>
         // snowflake-todo: try to provide more error information,
@@ -830,7 +787,20 @@ private[io] object StageWriter {
     ): SnowflakeSQLStatement =
     format match {
       case SupportedFormat.PARQUET =>
-        EmptySnowflakeSQLStatement()
+        ConstantString("(") + params.toSnowflakeSchema(schema)
+          .map(
+            field =>
+              if (params.quoteJsonFieldName) {
+                if (params.keepOriginalColumnNameCase) {
+                  Utils.quotedNameIgnoreCase(field.name)
+                } else {
+                  Utils.ensureQuoted(field.name)
+                }
+              } else {
+                field.name
+              }
+          )
+          .mkString(",") + ")"
       case SupportedFormat.JSON =>
         val tableSchema =
           DefaultJDBCWrapper.resolveTable(conn, table.name, params)
@@ -886,7 +856,12 @@ private[io] object StageWriter {
     ): SnowflakeSQLStatement =
     format match {
       case SupportedFormat.PARQUET =>
-        from
+        ConstantString("from (select") +
+          schema.map(
+            field =>
+              "$1:" + "\"" + field.name + "\""
+          ).mkString(",") +
+          from + "tmp)"
       case SupportedFormat.JSON =>
         val columnPrefix = if (params.useParseJsonForWrite) "parse_json($1):" else "$1:"
         if (list.isEmpty || list.get.isEmpty) {
@@ -971,7 +946,6 @@ private[io] object StageWriter {
              |  TYPE=PARQUET
              |  USE_VECTORIZED_SCANNER=TRUE
              |  )
-             |  MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
            """.stripMargin) !
       case SupportedFormat.CSV =>
         ConstantString(s"""
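Taken together, the new column-list and transform fragments mean the Parquet
COPY no longer needs MATCH_BY_COLUMN_NAME: each target column is populated
explicitly from the Parquet record field of the same name via $1:"NAME". A
hedged sketch of the statement shape these fragments assemble (table and
stage names are made up; the connector builds the real statement from
SnowflakeSQLStatement pieces):

    // Sketch only: approximates the COPY generated for a two-column case.
    val columns    = Seq("INT_COL", "STRING_COL")
    val columnList = columns.map(c => "\"" + c + "\"").mkString("(", ",", ")")
    val transform  = columns.map(c => "$1:\"" + c + "\"").mkString(",")

    val copySql =
      s"""COPY INTO test_table $columnList
         |FROM (SELECT $transform FROM @tmp_stage tmp)
         |FILE_FORMAT = (
         |  TYPE=PARQUET
         |  USE_VECTORIZED_SCANNER=TRUE
         |)""".stripMargin

    // => COPY INTO test_table ("INT_COL","STRING_COL")
    //    FROM (SELECT $1:"INT_COL",$1:"STRING_COL" FROM @tmp_stage tmp) ...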