diff --git a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
index f9725b32..b8c48539 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -25,6 +25,7 @@ class ParquetSuite extends IntegrationSuiteBase {
   val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
   val test_nested_dataframe: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
   val test_no_staging_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+  val test_table_name: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
 
   override def afterAll(): Unit = {
     jdbcUpdate(s"drop table if exists $test_all_type")
@@ -41,6 +42,7 @@ class ParquetSuite extends IntegrationSuiteBase {
     jdbcUpdate(s"drop table if exists $test_column_map_not_match")
     jdbcUpdate(s"drop table if exists $test_nested_dataframe")
     jdbcUpdate(s"drop table if exists $test_no_staging_table")
+    jdbcUpdate(s"drop table if exists $test_table_name")
     super.afterAll()
   }
 
@@ -707,4 +709,53 @@ class ParquetSuite extends IntegrationSuiteBase {
     val res = sparkSession.sql(s"show tables like '%${test_all_type}_STAGING%'").collect()
     assert(res.length == 0)
   }
+
+  test("use parquet in structured type by default") {
+    // use CSV by default
+    sparkSession
+      .sql("select 1")
+      .write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_table_name)
+      .mode(SaveMode.Overwrite)
+      .save()
+    assert(Utils.getLastCopyLoad.contains("TYPE=CSV"))
+
+    // use Parquet on structured types
+    sparkSession
+      .sql("select array(1, 2)")
+      .write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_table_name)
+      .mode(SaveMode.Overwrite)
+      .save()
+    assert(Utils.getLastCopyLoad.contains("TYPE=PARQUET"))
+
+    // use JSON on structured types when PARAM_USE_JSON_IN_STRUCTURED_DATA is true
+    sparkSession
+      .sql("select array(1, 2)")
+      .write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_table_name)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
+      .mode(SaveMode.Overwrite)
+      .save()
+    assert(Utils.getLastCopyLoad.contains("TYPE = JSON"))
+
+    // PARAM_USE_PARQUET_IN_WRITE overrides PARAM_USE_JSON_IN_STRUCTURED_DATA
+    sparkSession
+      .sql("select array(1, 2)")
+      .write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_table_name)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
+      .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
+      .mode(SaveMode.Overwrite)
+      .save()
+    assert(Utils.getLastCopyLoad.contains("TYPE=PARQUET"))
+  }
 }
\ No newline at end of file
diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
index a0cd6df7..8e58f9a9 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
@@ -1866,6 +1866,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       .format(SNOWFLAKE_SOURCE_NAME)
       .options(thisConnectorOptionsNoTable)
       .option("dbtable", test_table_write)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
@@ -1922,6 +1923,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       .format(SNOWFLAKE_SOURCE_NAME)
       .options(localSFOption)
       .option("dbtable", test_table_write)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
@@ -2004,6 +2006,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       .format(SNOWFLAKE_SOURCE_NAME)
       .options(thisConnectorOptionsNoTable)
       .option("dbtable", test_table_write)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
@@ -2019,6 +2022,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       .format(SNOWFLAKE_SOURCE_NAME)
       .options(thisConnectorOptionsNoTable)
       .option("dbtable", test_table_write)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .option(Parameters.PARAM_INTERNAL_QUOTE_JSON_FIELD_NAME, "false")
       .mode(SaveMode.Overwrite)
       .save()
@@ -2031,6 +2035,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       .options(thisConnectorOptionsNoTable)
       .option("dbtable", test_table_write)
       .option(Parameters.PARAM_INTERNAL_QUOTE_JSON_FIELD_NAME, "false")
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
diff --git a/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
index 3ca5af58..23e3560d 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
@@ -164,6 +164,7 @@ class VariantTypeSuite extends IntegrationSuiteBase {
       .options(connectorOptionsNoTable)
       .option("dbtable", tableName2)
       .option(Parameters.PARAM_INTERNAL_USE_PARSE_JSON_FOR_WRITE, "true")
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
@@ -242,6 +243,7 @@ class VariantTypeSuite extends IntegrationSuiteBase {
       .format(SNOWFLAKE_SOURCE_NAME)
       .options(connectorOptionsNoTable)
       .option("dbtable", tableName4)
+      .option(Parameters.PARAM_USE_JSON_IN_STRUCTURED_DATA, "true")
       .mode(SaveMode.Overwrite)
       .save()
 
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index 63a5b10c..193c73b1 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -260,9 +260,18 @@ object Parameters {
   val BOOLEAN_VALUES_FALSE: Set[String] =
     Set("off", "no", "false", "0", "disabled")
 
-  // enable parquet format
+  // Enable the Parquet file format when loading data from Spark to Snowflake.
+  // When enabled, the Spark connector will only use the Parquet file format.
   val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write")
 
+  // By default, the Spark connector uses the CSV format when loading data
+  // from Spark to Snowflake. If the dataframe contains any structured type,
+  // the connector uses the Parquet format instead of CSV.
+  // When this parameter is enabled, the connector uses the JSON format,
+  // rather than Parquet, when loading structured data.
+  // This parameter is ignored if the USE_PARQUET_IN_WRITE parameter is enabled.
+  val PARAM_USE_JSON_IN_STRUCTURED_DATA: String = knownParam("use_json_in_structured_data")
+
   /**
    * Helper method to check if a given string represents some form
    * of "true" value, see BOOLEAN_VALUES_TRUE
@@ -297,7 +306,8 @@ object Parameters {
     PARAM_TIMESTAMP_LTZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
     PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
     PARAM_TRIM_SPACE -> "false",
-    PARAM_USE_PARQUET_IN_WRITE -> "false"
+    PARAM_USE_PARQUET_IN_WRITE -> "false",
+    PARAM_USE_JSON_IN_STRUCTURED_DATA -> "false"
   )
 
 
@@ -613,13 +623,25 @@ object Parameters {
     def createPerQueryTempDir(): String = Utils.makeTempPath(rootTempDir)
 
 
    /**
-     * Use parquet form in download by default
+     * Use the Parquet format when loading data from Spark to Snowflake
     */
    def useParquetInWrite(): Boolean = {
      isTrue(parameters.getOrElse(PARAM_USE_PARQUET_IN_WRITE, "false"))
    }
 
+    /**
+     * Use the JSON format when loading structured data from Spark to Snowflake
+     */
+    def useJsonInWrite(): Boolean = {
+      if (useParquetInWrite()) {
+        // the USE_PARQUET_IN_WRITE parameter overrides this one
+        false
+      } else {
+        isTrue(parameters.getOrElse(PARAM_USE_JSON_IN_STRUCTURED_DATA, "false"))
+      }
+    }
+
    /**
     * The Snowflake table to be used as the target when loading or writing data.
     */
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index 14e94eb3..5257ac79 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -57,7 +57,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       if (params.useParquetInWrite()) {
         SupportedFormat.PARQUET
       } else if (Utils.containVariant(data.schema)){
-        SupportedFormat.JSON
+        if (params.useJsonInWrite()) SupportedFormat.JSON else SupportedFormat.PARQUET
       } else {
         SupportedFormat.CSV
       }
@@ -74,7 +74,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
         )
         params.setColumnMap(Option(data.schema), toSchema)
       } finally conn.close()
-    } else if (params.columnMap.isDefined && params.useParquetInWrite()){
+    } else if (params.columnMap.isDefined && format == SupportedFormat.PARQUET){
       val conn = jdbcWrapper.getConnector(params)
       try {
         val toSchema = Utils.removeQuote(
@@ -94,7 +94,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       } finally conn.close()
     }
 
-    if (params.useParquetInWrite()){
+    if (format == SupportedFormat.PARQUET){
       val conn = jdbcWrapper.getConnector(params)
       try{
         if (jdbcWrapper.tableExists(params, params.table.get.name)){
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index 2c104bd1..2d8dc0cb 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -327,7 +327,7 @@ private[io] object StageWriter {
 
     if (!tableExists) {
       writeTableState.createTable(tableName,
-        if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
+        if (format == SupportedFormat.PARQUET) params.toSnowflakeSchema(schema) else schema,
         params)
     } else if (params.truncateTable && saveMode == SaveMode.Overwrite) {
       writeTableState.truncateTable(tableName)
@@ -425,7 +425,7 @@ private[io] object StageWriter {
 
     if (saveMode == SaveMode.Overwrite || !tableExists) {
       conn.createTable(targetTable.name,
-        if (params.useParquetInWrite()) params.toSnowflakeSchema(schema) else schema,
+        if (format == SupportedFormat.PARQUET) params.toSnowflakeSchema(schema) else schema,
         params, overwrite = false, temporary = false)
     }
 
@@ -904,7 +904,7 @@ private[io] object StageWriter {
     val fromString = ConstantString(s"FROM @$tempStage/$prefix/") !
 
     val mappingList: Option[List[(Int, String)]] =
-      if (params.useParquetInWrite()) None else params.columnMap match {
+      if (format == SupportedFormat.PARQUET) None else params.columnMap match {
        case Some(map) =>
          Some(map.toList.map {
            case (key, value) =>
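
Reviewer note: a minimal standalone sketch of the write-format precedence this patch introduces, for quick reference. It is illustrative only — chooseFormat and hasStructuredType are hypothetical helper names, not connector API; the real decision is made in SnowflakeWriter using Parameters.useParquetInWrite(), Parameters.useJsonInWrite(), and Utils.containVariant.

import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

object FormatPrecedenceSketch {
  sealed trait Format
  case object CSV extends Format
  case object PARQUET extends Format
  case object JSON extends Format

  // Assumption: "structured type" means an array, map, or struct column,
  // roughly what Utils.containVariant detects in the dataframe schema.
  private def isStructured(dt: DataType): Boolean = dt match {
    case _: ArrayType | _: MapType | _: StructType => true
    case _ => false
  }

  def hasStructuredType(schema: StructType): Boolean =
    schema.fields.exists(f => isStructured(f.dataType))

  // use_parquet_in_write always wins; otherwise structured schemas use JSON
  // only when use_json_in_structured_data is set, else Parquet;
  // non-structured schemas keep the CSV default.
  def chooseFormat(schema: StructType,
                   useParquetInWrite: Boolean,
                   useJsonInStructuredData: Boolean): Format =
    if (useParquetInWrite) PARQUET
    else if (hasStructuredType(schema)) {
      if (useJsonInStructuredData) JSON else PARQUET
    } else CSV
}

Under these rules the four COPY assertions in the new ParquetSuite test follow directly: CSV for "select 1", Parquet for "select array(1, 2)", JSON once use_json_in_structured_data is set, and Parquet again when use_parquet_in_write is also set and overrides it.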