From 1b74cd095849b067a232439450f2bb4cc34d852e Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 11:32:13 -0700 Subject: [PATCH 1/8] fix parameter --- src/main/scala/net/snowflake/spark/snowflake/Parameters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala index 44690aee..63a5b10c 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala @@ -261,7 +261,7 @@ object Parameters { Set("off", "no", "false", "0", "disabled") // enable parquet format - val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write ") + val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write") /** * Helper method to check if a given string represents some form From d7a952b21dc4bab8509241d441c9f351273a632b Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 11:40:41 -0700 Subject: [PATCH 2/8] remove timestamp NTZ to be compatible with Spark 3.3- --- .../net/snowflake/spark/snowflake/ParquetSuite.scala | 8 ++++---- .../net/snowflake/spark/snowflake/io/ParquetUtils.scala | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala index 134ece03..81ffc29e 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala @@ -502,7 +502,7 @@ class ParquetSuite extends IntegrationSuiteBase { // throw exception because only support SaveMode.Append assertThrows[UnsupportedOperationException] { df.write - .format(SNOWFLAKE_SOURCE_SHORT_NAME) + .format(SNOWFLAKE_SOURCE_NAME) .options(connectorOptionsNoTable) .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") .option("dbtable", test_column_map_parquet) @@ -514,7 +514,7 @@ class ParquetSuite extends IntegrationSuiteBase { // throw exception because "aaa" is not a column name of DF assertThrows[IllegalArgumentException] { df.write - .format(SNOWFLAKE_SOURCE_SHORT_NAME) + .format(SNOWFLAKE_SOURCE_NAME) .options(connectorOptionsNoTable) .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") .option("dbtable", test_column_map_parquet) @@ -526,7 +526,7 @@ class ParquetSuite extends IntegrationSuiteBase { // throw exception because "AAA" is not a column name of table in snowflake database assertThrows[IllegalArgumentException] { df.write - .format(SNOWFLAKE_SOURCE_SHORT_NAME) + .format(SNOWFLAKE_SOURCE_NAME) .options(connectorOptionsNoTable) .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") .option("dbtable", test_column_map_parquet) @@ -547,7 +547,7 @@ class ParquetSuite extends IntegrationSuiteBase { assertThrows[SQLException]{ df1.write - .format(SNOWFLAKE_SOURCE_SHORT_NAME) + .format(SNOWFLAKE_SOURCE_NAME) .options(connectorOptionsNoTable) .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") .option("dbtable", test_column_map_not_match) diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala index 2376c7cc..ca7b9990 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala @@ -1,7 +1,7 @@ package net.snowflake.spark.snowflake.io import org.apache.avro.{Schema, SchemaBuilder} -import 
org.apache.avro.SchemaBuilder.{BaseFieldTypeBuilder, BaseTypeBuilder, FieldDefault, RecordBuilder, nullable} +import org.apache.avro.SchemaBuilder.{BaseFieldTypeBuilder, BaseTypeBuilder, FieldDefault, RecordBuilder} import org.apache.parquet.io.{OutputFile, PositionOutputStream} import org.apache.spark.sql.types._ @@ -53,7 +53,7 @@ object ParquetUtils { builder.stringBuilder() .prop("logicalType", "date") .endString() - case TimestampType | TimestampNTZType => + case TimestampType => builder.stringBuilder() .prop("logicalType", " timestamp-micros") .endString() From cc16590882e8848e2dc57fa40841367ebbb22b34 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 11:51:37 -0700 Subject: [PATCH 3/8] refactor cloud operations --- .../snowflake/io/CloudStorageOperations.scala | 252 +++--------------- 1 file changed, 41 insertions(+), 211 deletions(-) diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index fcb1da2e..18f4e2e1 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -660,245 +660,75 @@ sealed trait CloudStorage { // partition can be empty, can't generate empty parquet file, // so skip the empty partition. if (input.nonEmpty) { - if (storageInfo.isDefined) { - val uploadStream = createUploadStream( + val uploadStream = if (storageInfo.isDefined) { + createUploadStream( fileName, Some(directory), // compress, if (format == SupportedFormat.PARQUET) false else compress, storageInfo.get) - try { - format match { - case SupportedFormat.PARQUET => - val rows = input.asInstanceOf[Iterator[GenericData.Record]].toSeq - val writer = AvroParquetWriter.builder[GenericData.Record]( - new ParquetUtils.StreamOutputFile(uploadStream) - ).withSchema(rows.head.getSchema) - .withCompressionCodec(CompressionCodecName.SNAPPY) - .build() - rows.foreach(writer.write) - writer.close() - case _ => - val rows = input.asInstanceOf[Iterator[String]] - while (rows.hasNext) { - val oneRow = rows.next.getBytes("UTF-8") - uploadStream.write(oneRow) - uploadStream.write('\n') - rowCount += 1 - dataSize += (oneRow.size + 1) - } - } - } finally { - uploadStream.close() - } - val endTime = System.currentTimeMillis() - processTimeInfo = - s"""read_and_upload_time: - | ${Utils.getTimeString(endTime - startTime)} - |""".stripMargin.filter(_ >= ' ') - } else if (fileTransferMetadata.isDefined) { + } else { + new ByteArrayOutputStream(4 * 1024 * 1024) + } + try { format match { case SupportedFormat.PARQUET => val rows = input.asInstanceOf[Iterator[GenericData.Record]].toSeq - val outputStream = new ByteArrayOutputStream() val writer = AvroParquetWriter.builder[GenericData.Record]( - new ParquetUtils.StreamOutputFile(outputStream) + new ParquetUtils.StreamOutputFile(uploadStream) ).withSchema(rows.head.getSchema) .withCompressionCodec(CompressionCodecName.SNAPPY) .build() rows.foreach(writer.write) + rowCount += rows.size + dataSize += writer.getDataSize writer.close() - - val data = outputStream.toByteArray - dataSize = data.size - outputStream.close() - - // Set up proxy info if it is configured. 
- val proxyProperties = new Properties() - proxyInfo match { - case Some(proxyInfoValue) => - proxyInfoValue.setProxyForJDBC(proxyProperties) - case None => - } - - // Upload data with FileTransferMetadata - val startUploadTime = System.currentTimeMillis() - val inStream = new ByteArrayInputStream(data) - SnowflakeFileTransferAgent.uploadWithoutConnection( - SnowflakeFileTransferConfig.Builder.newInstance() - .setSnowflakeFileTransferMetadata(fileTransferMetadata.get) - .setUploadStream(inStream) - .setRequireCompress(false) - .setDestFileName(fileName) - .setOcspMode(OCSPMode.FAIL_OPEN) - .setProxyProperties(proxyProperties) - .build()) - val endTime = System.currentTimeMillis() - processTimeInfo = - s"""read_and_upload_time: - | ${Utils.getTimeString(endTime - startTime)} - | read_time: ${Utils.getTimeString(startUploadTime - startTime)} - | upload_time: ${Utils.getTimeString(endTime - startUploadTime)} - |""".stripMargin.filter(_ >= ' ') - case _ => - val outputStream = new ByteArrayOutputStream(4 * 1024 * 1024) val rows = input.asInstanceOf[Iterator[String]] while (rows.hasNext) { val oneRow = rows.next.getBytes("UTF-8") - outputStream.write(oneRow) - outputStream.write('\n') + uploadStream.write(oneRow) + uploadStream.write('\n') rowCount += 1 dataSize += (oneRow.size + 1) } - val data = outputStream.toByteArray - dataSize = data.size - outputStream.close() - - // Set up proxy info if it is configured. - val proxyProperties = new Properties() - proxyInfo match { - case Some(proxyInfoValue) => - proxyInfoValue.setProxyForJDBC(proxyProperties) - case None => - } + } + } finally { + if (storageInfo.isDefined) { + uploadStream.close() + } else { + val data = uploadStream.asInstanceOf[ByteArrayOutputStream].toByteArray + dataSize = data.size + uploadStream.close() + // Set up proxy info if it is configured. 
+ val proxyProperties = new Properties() + proxyInfo match { + case Some(proxyInfoValue) => + proxyInfoValue.setProxyForJDBC(proxyProperties) + case None => + } - // Upload data with FileTransferMetadata - val startUploadTime = System.currentTimeMillis() - val inStream = new ByteArrayInputStream(data) - SnowflakeFileTransferAgent.uploadWithoutConnection( - SnowflakeFileTransferConfig.Builder.newInstance() - .setSnowflakeFileTransferMetadata(fileTransferMetadata.get) - .setUploadStream(inStream) - .setRequireCompress(compress) - .setDestFileName(fileName) - .setOcspMode(OCSPMode.FAIL_OPEN) - .setProxyProperties(proxyProperties) - .build()) - val endTime = System.currentTimeMillis() - processTimeInfo = - s"""read_and_upload_time: - | ${Utils.getTimeString(endTime - startTime)} - | read_time: ${Utils.getTimeString(startUploadTime - startTime)} - | upload_time: ${Utils.getTimeString(endTime - startUploadTime)} - |""".stripMargin.filter(_ >= ' ') + val inStream = new ByteArrayInputStream(data) + SnowflakeFileTransferAgent.uploadWithoutConnection( + SnowflakeFileTransferConfig.Builder.newInstance() + .setSnowflakeFileTransferMetadata(fileTransferMetadata.get) + .setUploadStream(inStream) + .setRequireCompress(false) + .setDestFileName(fileName) + .setOcspMode(OCSPMode.FAIL_OPEN) + .setProxyProperties(proxyProperties) + .build()) + val endTime = System.currentTimeMillis() + processTimeInfo = + s"""read_and_upload_time: + | ${Utils.getTimeString(endTime - startTime)} + |""".stripMargin.filter(_ >= ' ') } } } else { logger.info(s"Empty partition, skipped file $fileName") } - - // todo: handle GCP - // When attempt number is smaller than 2, throw exception - if (TaskContext.get().attemptNumber() < 2) { - TestHook.raiseExceptionIfTestFlagEnabled( - TestHookFlag.TH_GCS_UPLOAD_RAISE_EXCEPTION, - "Negative test to raise error when uploading data for the first two attempts" - ) - } - - CloudStorageOperations.log.info( - s"""${SnowflakeResultSetRDD.WORKER_LOG_PREFIX}: - | Finish writing partition ID:$partitionID $fileName - | write row count is $rowCount. - | Uncompressed data size is ${Utils.getSizeString(dataSize)}. - | $processTimeInfo - |""".stripMargin.filter(_ >= ' ')) - - new SingleElementIterator(new FileUploadResult(s"$directory/$fileName", dataSize, rowCount)) - } - // Read data and upload to cloud storage - private def doUploadPartitionV1(rows: Iterator[String], - format: SupportedFormat, - compress: Boolean, - directory: String, - partitionID: Int, - storageInfo: Option[Map[String, String]], - fileTransferMetadata: Option[SnowflakeFileTransferMetadata] - ) - : SingleElementIterator = { - val fileName = getFileName(partitionID, format, compress) - - CloudStorageOperations.log.info( - s"""${SnowflakeResultSetRDD.WORKER_LOG_PREFIX}: - | Start writing partition ID:$partitionID as $fileName - | TaskInfo: ${SnowflakeTelemetry.getTaskInfo().toPrettyString} - |""".stripMargin.filter(_ >= ' ')) - - // Read data and upload to cloud storage - var rowCount: Long = 0 - var dataSize: Long = 0 - var processTimeInfo = "" - val startTime = System.currentTimeMillis() - if (storageInfo.isDefined) { - // For AWS and Azure, the rows are written to OutputStream as they are read. - var uploadStream: Option[OutputStream] = None - while (rows.hasNext) { - // Defer to create the upload stream to avoid empty files. 
- if (uploadStream.isEmpty) { - uploadStream = Some(createUploadStream( - fileName, Some(directory), compress, storageInfo.get)) - } - val oneRow = rows.next.getBytes("UTF-8") - uploadStream.get.write(oneRow) - uploadStream.get.write('\n') - rowCount += 1 - dataSize += (oneRow.size + 1) - } - if (uploadStream.isDefined) { - uploadStream.get.close() - } - - val endTime = System.currentTimeMillis() - processTimeInfo = - s"""read_and_upload_time: - | ${Utils.getTimeString(endTime - startTime)} - |""".stripMargin.filter(_ >= ' ') - } - // For GCP, the rows are cached and then uploaded. - else if (fileTransferMetadata.isDefined) { - // cache the data in buffer - val outputStream = new ByteArrayOutputStream(4 * 1024 * 1024) - while (rows.hasNext) { - outputStream.write(rows.next.getBytes("UTF-8")) - outputStream.write('\n') - rowCount += 1 - } - val data = outputStream.toByteArray - dataSize = data.size - outputStream.close() - - // Set up proxy info if it is configured. - val proxyProperties = new Properties() - proxyInfo match { - case Some(proxyInfoValue) => - proxyInfoValue.setProxyForJDBC(proxyProperties) - case None => - } - - // Upload data with FileTransferMetadata - val startUploadTime = System.currentTimeMillis() - val inStream = new ByteArrayInputStream(data) - SnowflakeFileTransferAgent.uploadWithoutConnection( - SnowflakeFileTransferConfig.Builder.newInstance() - .setSnowflakeFileTransferMetadata(fileTransferMetadata.get) - .setUploadStream(inStream) - .setRequireCompress(compress) - .setDestFileName(fileName) - .setOcspMode(OCSPMode.FAIL_OPEN) - .setProxyProperties(proxyProperties) - .build()) - - val endTime = System.currentTimeMillis() - processTimeInfo = - s"""read_and_upload_time: - | ${Utils.getTimeString(endTime - startTime)} - | read_time: ${Utils.getTimeString(startUploadTime - startTime)} - | upload_time: ${Utils.getTimeString(endTime - startUploadTime)} - |""".stripMargin.filter(_ >= ' ') - } - // When attempt number is smaller than 2, throw exception if (TaskContext.get().attemptNumber() < 2) { TestHook.raiseExceptionIfTestFlagEnabled( From 793aa21f188c9e401acc51787dc8f0509f9402d0 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 12:03:19 -0700 Subject: [PATCH 4/8] support null value in array --- .../spark/snowflake/ParquetSuite.scala | 35 +++++++++++++++++++ .../snowflake/io/CloudStorageOperations.scala | 4 +++ 2 files changed, 39 insertions(+) diff --git a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala index 81ffc29e..398ee41a 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala @@ -536,6 +536,41 @@ class ParquetSuite extends IntegrationSuiteBase { } } + test("null value in array") { + val data: RDD[Row] = sc.makeRDD( + List( + Row( + Array(null, "one", "two", "three"), + ), + Row( + Array("one", null, "two", "three"), + ) + ) + ) + + val schema = StructType(List( + StructField("ARRAY_STRING_FIELD", + ArrayType(StringType, containsNull = true), nullable = true))) + val df = sparkSession.createDataFrame(data, schema) + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_array_map) + .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true") + .mode(SaveMode.Overwrite) + .save() + + + val res = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_array_map) + .schema(schema) + 
.load().collect() + assert(res.head.getSeq(0) == Seq("null", "one", "two", "three")) + assert(res(1).getSeq(0) == Seq("one", "null", "two", "three")) + } + test("test error when column map does not match") { jdbcUpdate(s"create or replace table $test_column_map_not_match (num int, str string)") // auto map diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index 18f4e2e1..23f31b93 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -46,6 +46,7 @@ import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag} import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration import org.apache.parquet.avro.AvroParquetWriter import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.{SparkContext, TaskContext} @@ -674,9 +675,12 @@ sealed trait CloudStorage { format match { case SupportedFormat.PARQUET => val rows = input.asInstanceOf[Iterator[GenericData.Record]].toSeq + val config = new Configuration() + config.setBoolean("parquet.avro.write-old-list-structure", false) val writer = AvroParquetWriter.builder[GenericData.Record]( new ParquetUtils.StreamOutputFile(uploadStream) ).withSchema(rows.head.getSchema) + .withConf(config) .withCompressionCodec(CompressionCodecName.SNAPPY) .build() rows.foreach(writer.write) From a162c7d297c1a7d8b483c4f15ccc6c29e7ba6ab0 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 13:56:06 -0700 Subject: [PATCH 5/8] refactor cloud operations --- .../spark/snowflake/io/StageSuite.scala | 15 ++---- .../snowflake/io/CloudStorageOperations.scala | 53 +++++-------------- 2 files changed, 17 insertions(+), 51 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala index 1a310c75..e1d80cf3 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala @@ -311,16 +311,13 @@ class StageSuite extends IntegrationSuiteBase { try { // The credential for the external stage is fake. val azureExternalStage = ExternalAzureStorage( + param, containerName = "test_fake_container", azureAccount = "test_fake_account", azureEndpoint = "blob.core.windows.net", azureSAS = "?sig=test_test_test_test_test_test_test_test_test_test_test_test" + "_test_test_test_test_test_fak&spr=https&sp=rwdl&sr=c", - param.proxyInfo, - param.maxRetryCount, - param.sfURL, - param.useExponentialBackoff, param.expectedPartitionCount, pref = "test_dir", connection = connection @@ -367,13 +364,10 @@ class StageSuite extends IntegrationSuiteBase { try { // The credential for the external stage is fake. val s3ExternalStage = ExternalS3Storage( + param, bucketName = "test_fake_bucket", awsId = "TEST_TEST_TEST_TEST1", awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2", - param.proxyInfo, - param.maxRetryCount, - param.sfURL, - param.useExponentialBackoff, param.expectedPartitionCount, pref = "test_dir", connection = connection, @@ -487,13 +481,10 @@ class StageSuite extends IntegrationSuiteBase { try { // The credential for the external stage is fake. 
val s3ExternalStage = ExternalS3Storage( + param, bucketName = "test_fake_bucket", awsId = "TEST_TEST_TEST_TEST1", awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2", - param.proxyInfo, - param.maxRetryCount, - param.sfURL, - param.useExponentialBackoff, param.expectedPartitionCount, pref = "test_dir", connection = connection, diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index 23f31b93..bfc3c83a 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -290,14 +290,11 @@ object CloudStorageOperations { ( ExternalAzureStorage( + param = param, containerName = container, azureAccount = account, azureEndpoint = endpoint, azureSAS = azureSAS, - param.proxyInfo, - param.maxRetryCount, - param.sfURL, - param.useExponentialBackoff, param.expectedPartitionCount, pref = path, connection = conn @@ -321,13 +318,10 @@ object CloudStorageOperations { ( ExternalS3Storage( + param = param, bucketName = bucket, awsId = param.awsAccessKey.get, awsKey = param.awsSecretKey.get, - param.proxyInfo, - param.maxRetryCount, - param.sfURL, - param.useExponentialBackoff, param.expectedPartitionCount, pref = prefix, connection = conn, @@ -495,14 +489,15 @@ private[io] object StorageInfo { } sealed trait CloudStorage { + protected val param: MergedParameters protected val RETRY_SLEEP_TIME_UNIT_IN_MS: Int = 1500 protected val MAX_SLEEP_TIME_IN_MS: Int = 3 * 60 * 1000 private var processedFileCount = 0 protected val connection: ServerConnection - protected val maxRetryCount: Int - protected val proxyInfo: Option[ProxyInfo] - protected val sfURL: String - protected val useExponentialBackoff: Boolean + protected val maxRetryCount: Int = param.maxRetryCount + protected val proxyInfo: Option[ProxyInfo] = param.proxyInfo + protected val sfURL: String = param.sfURL + protected val useExponentialBackoff: Boolean = param.useExponentialBackoff // The first 10 sleep time in second will be like // 3, 6, 12, 24, 48, 96, 192, 300, 300, 300, etc @@ -921,15 +916,11 @@ sealed trait CloudStorage { def fileExists(fileName: String): Boolean } -case class InternalAzureStorage(param: MergedParameters, +case class InternalAzureStorage(override protected val param: MergedParameters, stageName: String, @transient override val connection: ServerConnection) extends CloudStorage { - override val maxRetryCount = param.maxRetryCount - override val proxyInfo: Option[ProxyInfo] = param.proxyInfo - override val sfURL = param.sfURL - override val useExponentialBackoff = param.useExponentialBackoff override protected def getStageInfo( isWrite: Boolean, @@ -1150,14 +1141,11 @@ case class InternalAzureStorage(param: MergedParameters, } } -case class ExternalAzureStorage(containerName: String, +case class ExternalAzureStorage(override protected val param: MergedParameters, + containerName: String, azureAccount: String, azureEndpoint: String, azureSAS: String, - override val proxyInfo: Option[ProxyInfo], - override val maxRetryCount: Int, - override val sfURL: String, - override val useExponentialBackoff: Boolean, fileCountPerPartition: Int, pref: String = "", @transient override val connection: ServerConnection) @@ -1302,16 +1290,12 @@ case class ExternalAzureStorage(containerName: String, } } -case class InternalS3Storage(param: MergedParameters, +case class InternalS3Storage(override protected val param: 
MergedParameters, stageName: String, @transient override val connection: ServerConnection, parallelism: Int = CloudStorageOperations.DEFAULT_PARALLELISM) extends CloudStorage { - override val maxRetryCount = param.maxRetryCount - override val proxyInfo: Option[ProxyInfo] = param.proxyInfo - override val sfURL = param.sfURL - override val useExponentialBackoff = param.useExponentialBackoff override protected def getStageInfo( isWrite: Boolean, @@ -1550,13 +1534,10 @@ case class InternalS3Storage(param: MergedParameters, } } -case class ExternalS3Storage(bucketName: String, +case class ExternalS3Storage(override protected val param: MergedParameters, + bucketName: String, awsId: String, awsKey: String, - override val proxyInfo: Option[ProxyInfo], - override val maxRetryCount: Int, - override val sfURL: String, - override val useExponentialBackoff: Boolean, fileCountPerPartition: Int, awsToken: Option[String] = None, pref: String = "", @@ -1704,18 +1685,12 @@ case class ExternalS3Storage(bucketName: String, // Internal CloudStorage for GCS (Google Cloud Storage). // NOTE: External storage for GCS is not supported. -case class InternalGcsStorage(param: MergedParameters, +case class InternalGcsStorage(override protected val param: MergedParameters, stageName: String, @transient override val connection: ServerConnection, @transient stageManager: SFInternalStage) extends CloudStorage { - override val proxyInfo: Option[ProxyInfo] = param.proxyInfo - // Max retry count to upload a file - override val maxRetryCount: Int = param.maxRetryCount - override val sfURL = param.sfURL - override val useExponentialBackoff = param.useExponentialBackoff - // Generate file transfer metadata objects for file upload. On GCS, // the file transfer metadata is pre-signed URL and related metadata. // This function is called on Master node. 
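The refactor in PATCH 5/8 replaces the four settings that were previously passed to every storage class (proxyInfo, maxRetryCount, sfURL, useExponentialBackoff) with defaults derived from a single param: MergedParameters member on the sealed CloudStorage trait, so each case class now overrides only param and call sites pass the merged parameters object once. A minimal sketch of the pattern, with the field list trimmed down and Settings standing in for the connector's MergedParameters:

    // Illustration only, not the connector's real classes.
    case class Settings(maxRetryCount: Int, sfURL: String, useExponentialBackoff: Boolean)

    sealed trait CloudStorage {
      // Concrete storages override only this one field...
      protected val param: Settings
      // ...and the shared settings are derived from it once, in the trait.
      protected val maxRetryCount: Int = param.maxRetryCount
      protected val sfURL: String = param.sfURL
      protected val useExponentialBackoff: Boolean = param.useExponentialBackoff
    }

    // A constructor-parameter val is initialized before the trait body runs,
    // so the derived defaults above see the real value.
    case class ExternalS3Storage(override protected val param: Settings,
                                 bucketName: String,
                                 awsId: String,
                                 awsKey: String) extends CloudStorage

Call sites shrink accordingly, as the StageSuite changes above show: ExternalS3Storage(param, bucketName = ..., awsId = ..., awsKey = ...).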
From 601828294cb8b2a7d40c6f5e05609d43573cbbaa Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 30 Oct 2024 14:28:05 -0700 Subject: [PATCH 6/8] move parquet writer --- .../spark/snowflake/SnowflakeWriter.scala | 51 +------------------ .../snowflake/io/CloudStorageOperations.scala | 42 +++++++-------- .../spark/snowflake/io/ParquetUtils.scala | 40 +++++++++++++++ .../spark/snowflake/io/StageWriter.scala | 2 +- .../spark/snowflake/io/package.scala | 3 +- 5 files changed, 62 insertions(+), 76 deletions(-) diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala index e37ef5fb..14e94eb3 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala @@ -17,27 +17,15 @@ package net.snowflake.spark.snowflake -import scala.collection.JavaConverters._ import java.sql.{Date, Timestamp} import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64 import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString} -import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters} -import net.snowflake.spark.snowflake.SparkConnectorContext.getClass -import net.snowflake.spark.snowflake.Utils.ensureUnquoted +import net.snowflake.spark.snowflake.Parameters.MergedParameters import net.snowflake.spark.snowflake.io.SupportedFormat import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat -import org.apache.avro.Schema -import org.apache.avro.generic.GenericData import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.RebaseDateTime -import org.slf4j.LoggerFactory - -import java.nio.ByteBuffer -import java.time.{LocalDate, ZoneId, ZoneOffset} -import java.util.concurrent.TimeUnit -import scala.collection.mutable /** * Functions to write data to Snowflake. 
@@ -198,42 +186,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) { format match { case SupportedFormat.PARQUET => val snowflakeStyleSchema = mapColumn(data.schema, params, snowflakeStyle = true) - val schema = io.ParquetUtils.convertStructToAvro(snowflakeStyleSchema) - (data.rdd.map (row => { - def rowToAvroRecord(row: Row, - schema: Schema, - snowflakeStyleSchema: StructType, - params: MergedParameters): GenericData.Record = { - val record = new GenericData.Record(schema) - row.toSeq.zip(snowflakeStyleSchema.names).foreach { - case (row: Row, name) => - record.put(name, - rowToAvroRecord( - row, - schema.getField(name).schema().getTypes.get(0), - snowflakeStyleSchema(name).dataType.asInstanceOf[StructType], - params - )) - case (map: scala.collection.immutable.Map[Any, Any], name) => - record.put(name, map.asJava) - case (str: String, name) => - record.put(name, if (params.trimSpace) str.trim else str) - case (arr: mutable.WrappedArray[Any], name) => - record.put(name, arr.toArray) - case (decimal: java.math.BigDecimal, name) => - record.put(name, ByteBuffer.wrap(decimal.unscaledValue().toByteArray)) - case (timestamp: java.sql.Timestamp, name) => - record.put(name, timestamp.toString) - case (date: java.sql.Date, name) => - record.put(name, date.toString) - case (date: java.time.LocalDateTime, name) => - record.put(name, date.toEpochSecond(ZoneOffset.UTC)) - case (value, name) => record.put(name, value) - } - record - } - rowToAvroRecord(row, schema, snowflakeStyleSchema, params) - }), snowflakeStyleSchema) + (data.rdd.asInstanceOf[RDD[Any]], snowflakeStyleSchema) case SupportedFormat.CSV => val conversionFunction = genConversionFunctions(data.schema, params) (data.rdd.map(row => { diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index bfc3c83a..0605bcc7 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -51,6 +51,8 @@ import org.apache.parquet.avro.AvroParquetWriter import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType import org.slf4j.{Logger, LoggerFactory} import scala.util.Random @@ -346,22 +348,6 @@ object CloudStorageOperations { } } - /** - * Save a string rdd to cloud storage - * - * @param data data frame object - * @param storage storage client - * @return a list of file name - */ - def saveToStorage( - data: RDD[Any], - format: SupportedFormat = SupportedFormat.CSV, - dir: Option[String] = None, - compress: Boolean = true - )(implicit storage: CloudStorage): List[String] = { - storage.upload(data, format, dir, compress).map(_.fileName) - } - def deleteFiles(files: List[String])(implicit storage: CloudStorage, connection: ServerConnection): Unit = storage.deleteFiles(files) @@ -530,9 +516,10 @@ sealed trait CloudStorage { def upload(data: RDD[Any], format: SupportedFormat = SupportedFormat.CSV, + schema: StructType, dir: Option[String], compress: Boolean = true): List[FileUploadResult] = - uploadRDD(data, format, dir, compress, getStageInfo(isWrite = true)._1) + uploadRDD(data, format, schema, dir, compress, getStageInfo(isWrite = true)._1) private[io] def checkUploadMetadata(storageInfo: Option[Map[String, String]], fileTransferMetadata: 
Option[SnowflakeFileTransferMetadata] @@ -562,6 +549,7 @@ sealed trait CloudStorage { // sleep time based on the task's attempt number. protected def uploadPartition(rows: Iterator[Any], format: SupportedFormat, + schema: StructType, compress: Boolean, directory: String, partitionID: Int, @@ -576,7 +564,7 @@ sealed trait CloudStorage { try { // Read data and upload to cloud storage - doUploadPartition(rows, format, compress, directory, partitionID, + doUploadPartition(rows, format, schema, compress, directory, partitionID, storageInfo, fileTransferMetadata) } catch { // Hit exception when uploading the file @@ -633,6 +621,7 @@ sealed trait CloudStorage { // Read data and upload to cloud storage private def doUploadPartition(input: Iterator[Any], format: SupportedFormat, + schema: StructType, compress: Boolean, directory: String, partitionID: Int, @@ -669,17 +658,20 @@ sealed trait CloudStorage { try { format match { case SupportedFormat.PARQUET => - val rows = input.asInstanceOf[Iterator[GenericData.Record]].toSeq + val avroSchema = io.ParquetUtils.convertStructToAvro(schema) val config = new Configuration() config.setBoolean("parquet.avro.write-old-list-structure", false) val writer = AvroParquetWriter.builder[GenericData.Record]( new ParquetUtils.StreamOutputFile(uploadStream) - ).withSchema(rows.head.getSchema) + ).withSchema(avroSchema) .withConf(config) .withCompressionCodec(CompressionCodecName.SNAPPY) .build() - rows.foreach(writer.write) - rowCount += rows.size + input.foreach { + case row: Row => + writer.write(ParquetUtils.rowToAvroRecord(row, avroSchema, schema, param)) + rowCount += 1 + } dataSize += writer.getDataSize writer.close() case _ => @@ -749,6 +741,7 @@ sealed trait CloudStorage { protected def uploadRDD(data: RDD[Any], format: SupportedFormat = SupportedFormat.CSV, + schema: StructType, dir: Option[String], compress: Boolean = true, storageInfo: Map[String, String]): List[FileUploadResult] = { @@ -779,7 +772,7 @@ sealed trait CloudStorage { SparkConnectorContext.recordConfig() // Convert and upload the partition with the StorageInfo - uploadPartition(rows, format, compress, directory, index, Some(storageInfo), None) + uploadPartition(rows, format, schema, compress, directory, index, Some(storageInfo), None) /////////////////////////////////////////////////////////////////////// // End code snippet to be executed on worker @@ -1765,6 +1758,7 @@ case class InternalGcsStorage(override protected val param: MergedParameters, // so override it separately. 
override def upload(data: RDD[Any], format: SupportedFormat = SupportedFormat.CSV, + schema: StructType, dir: Option[String], compress: Boolean = true): List[FileUploadResult] = { @@ -1808,7 +1802,7 @@ case class InternalGcsStorage(override protected val param: MergedParameters, metadatas.head } // Convert and upload the partition with the file transfer metadata - uploadPartition(rows, format, compress, directory, index, None, Some(metadata)) + uploadPartition(rows, format, schema, compress, directory, index, None, Some(metadata)) /////////////////////////////////////////////////////////////////////// // End code snippet to executed on worker diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala index ca7b9990..b860ec69 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala @@ -1,15 +1,55 @@ package net.snowflake.spark.snowflake.io +import net.snowflake.spark.snowflake.Parameters.MergedParameters import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.SchemaBuilder.{BaseFieldTypeBuilder, BaseTypeBuilder, FieldDefault, RecordBuilder} +import org.apache.avro.generic.GenericData import org.apache.parquet.io.{OutputFile, PositionOutputStream} +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import java.io.OutputStream +import java.nio.ByteBuffer +import java.time.ZoneOffset +import scala.collection.mutable +import scala.collection.JavaConverters._ object ParquetUtils { private val nameSpace = "snowflake" + def rowToAvroRecord(row: Row, + schema: Schema, + snowflakeStyleSchema: StructType, + params: MergedParameters): GenericData.Record = { + val record = new GenericData.Record(schema) + row.toSeq.zip(snowflakeStyleSchema.names).foreach { + case (row: Row, name) => + record.put(name, + rowToAvroRecord( + row, + schema.getField(name).schema().getTypes.get(0), + snowflakeStyleSchema(name).dataType.asInstanceOf[StructType], + params + )) + case (map: scala.collection.immutable.Map[Any, Any], name) => + record.put(name, map.asJava) + case (str: String, name) => + record.put(name, if (params.trimSpace) str.trim else str) + case (arr: mutable.WrappedArray[Any], name) => + record.put(name, arr.toArray) + case (decimal: java.math.BigDecimal, name) => + record.put(name, ByteBuffer.wrap(decimal.unscaledValue().toByteArray)) + case (timestamp: java.sql.Timestamp, name) => + record.put(name, timestamp.toString) + case (date: java.sql.Date, name) => + record.put(name, date.toString) + case (date: java.time.LocalDateTime, name) => + record.put(name, date.toEpochSecond(ZoneOffset.UTC)) + case (value, name) => record.put(name, value) + } + record + } + def convertStructToAvro(structType: StructType): Schema = convertStructToAvro( structType, diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala index 8ed35920..d0a8e1a8 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala @@ -216,7 +216,7 @@ private[io] object StageWriter { params, conn, tempStage = true, None, "load") val startTime = System.currentTimeMillis() - val fileUploadResults = storage.upload(rdd, format, None) + val fileUploadResults = storage.upload(rdd, format, schema, None) val startCopyInto = System.currentTimeMillis() if 
(fileUploadResults.nonEmpty) { diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/package.scala b/src/main/scala/net/snowflake/spark/snowflake/io/package.scala index 86961528..f4e46704 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/package.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/package.scala @@ -46,8 +46,7 @@ package object io { rdd: RDD[Any], schema: StructType, saveMode: SaveMode, - format: SupportedFormat = SupportedFormat.CSV, - mapper: Option[Map[String, String]] = None): Unit = + format: SupportedFormat = SupportedFormat.CSV): Unit = StageWriter.writeToStage(sqlContext, rdd, schema, saveMode, params, format) } From 4bc772aa6536e547b6f0d5183a8bf636c19dd2ff Mon Sep 17 00:00:00 2001 From: Bing Li Date: Fri, 1 Nov 2024 12:28:06 -0700 Subject: [PATCH 7/8] fix file size --- .../snowflake/io/CloudStorageOperations.scala | 19 ++++++++++++++++--- .../spark/snowflake/io/StageWriter.scala | 5 +++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index 0605bcc7..a93d5f9b 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -43,6 +43,7 @@ import net.snowflake.spark.snowflake.Parameters.MergedParameters import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag} +import org.apache.avro.Schema import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.commons.io.IOUtils @@ -485,6 +486,8 @@ sealed trait CloudStorage { protected val sfURL: String = param.sfURL protected val useExponentialBackoff: Boolean = param.useExponentialBackoff + protected var avroSchema: Option[String] = None + // The first 10 sleep time in second will be like // 3, 6, 12, 24, 48, 96, 192, 300, 300, 300, etc protected def retrySleepTimeInMS(retry: Int): Int = { @@ -658,7 +661,7 @@ sealed trait CloudStorage { try { format match { case SupportedFormat.PARQUET => - val avroSchema = io.ParquetUtils.convertStructToAvro(schema) + val avroSchema = new Schema.Parser().parse(this.avroSchema.get) val config = new Configuration() config.setBoolean("parquet.avro.write-old-list-structure", false) val writer = AvroParquetWriter.builder[GenericData.Record]( @@ -672,7 +675,6 @@ sealed trait CloudStorage { writer.write(ParquetUtils.rowToAvroRecord(row, avroSchema, schema, param)) rowCount += 1 } - dataSize += writer.getDataSize writer.close() case _ => val rows = input.asInstanceOf[Iterator[String]] @@ -728,11 +730,12 @@ sealed trait CloudStorage { ) } + val dataSizeStr = if (dataSize == 0) "N/A" else Utils.getSizeString(dataSize) CloudStorageOperations.log.info( s"""${SnowflakeResultSetRDD.WORKER_LOG_PREFIX}: | Finish writing partition ID:$partitionID $fileName | write row count is $rowCount. - | Uncompressed data size is ${Utils.getSizeString(dataSize)}. + | Uncompressed data size is $dataSizeStr. | $processTimeInfo |""".stripMargin.filter(_ >= ' ')) @@ -758,6 +761,16 @@ sealed trait CloudStorage { | partitions: directory=$directory ${format.toString} $compress |""".stripMargin.filter(_ >= ' ')) + // 1. Avro Schema is not serializable in Spark 3.1.1. + // 2. 
Somehow, the Avro Schema can be created only one time with Schema builder. + // Therefore, if we create schema in the mapPartition function, we will get some error. + // e.g. cannot process decimal data. + // Alternatively, we create schema only one time here, and serialize the Json string to + // each partition, and then deserialize the Json string to avro schema in the partition. + if (format == SupportedFormat.PARQUET) { + this.avroSchema = Some(io.ParquetUtils.convertStructToAvro(schema).toString()) + } + // Some explain for newbies on spark connector: // Bellow code is executed in distributed by spark FRAMEWORK // 1. The master node executes "data.mapPartitionsWithIndex()" diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala index d0a8e1a8..cc23f314 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala @@ -562,10 +562,11 @@ private[io] object StageWriter { totalSize += fileUploadResult.fileSize totalRowCount += fileUploadResult.rowCount }) + val fileSizeStr = if (totalSize == 0) "N/A" else Utils.getSizeString(totalSize) logAndAppend(progress, s"Total file count is ${fileUploadResults.size}, " + s"non-empty files count is ${expectedFileSet.size}, " + - s"total file size is ${Utils.getSizeString(totalSize)}, " + - s"total row count is ${Utils.getSizeString(totalRowCount)}.") + s"total file size is $fileSizeStr, " + + s"total row count is $totalRowCount.") // Indicate whether to use FILES clause in the copy command var useFilesClause = false From de48a9d55a9060c311913425c1dc02d306fe0c77 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Fri, 1 Nov 2024 17:06:15 -0700 Subject: [PATCH 8/8] fix gcs --- .../spark/snowflake/io/CloudStorageOperations.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index a93d5f9b..5f551c96 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -1795,6 +1795,16 @@ case class InternalGcsStorage(override protected val param: MergedParameters, // If the partition count is 0, no metadata is created. val oneMetadataPerFile = metadatas.nonEmpty && metadatas.head.isForOneFile + // 1. Avro Schema is not serializable in Spark 3.1.1. + // 2. Somehow, the Avro Schema can be created only one time with Schema builder. + // Therefore, if we create schema in the mapPartition function, we will get some error. + // e.g. cannot process decimal data. + // Alternatively, we create schema only one time here, and serialize the Json string to + // each partition, and then deserialize the Json string to avro schema in the partition. + if (format == SupportedFormat.PARQUET) { + this.avroSchema = Some(io.ParquetUtils.convertStructToAvro(schema).toString()) + } + // Some explain for newbies on spark connector: // Bellow code is executed in distributed by spark FRAMEWORK // 1. The master node executes "data.mapPartitionsWithIndex()"
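PATCH 7/8 and PATCH 8/8 apply the same workaround in the two upload paths: the Avro schema is built once on the driver with ParquetUtils.convertStructToAvro, carried to the executors as its JSON string (the added comments note that the Avro Schema object is not serializable under Spark 3.1.1), and re-parsed inside each partition with new Schema.Parser().parse(...) before the AvroParquetWriter is built. A minimal sketch of that round trip, assuming ParquetUtils.convertStructToAvro is reachable from the calling code and using a made-up two-column schema:

    import org.apache.avro.Schema
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
    import net.snowflake.spark.snowflake.io.ParquetUtils

    // Driver side: build the Avro schema once and keep only its JSON form,
    // since a plain String serializes cleanly into the partition closure.
    val sparkSchema = StructType(Seq(
      StructField("NUM", LongType),
      StructField("STR", StringType)))
    val schemaJson: String = ParquetUtils.convertStructToAvro(sparkSchema).toString()

    // Executor side (inside mapPartitionsWithIndex): rebuild the Schema from the
    // JSON string before constructing the partition's AvroParquetWriter.
    val avroSchema: Schema = new Schema.Parser().parse(schemaJson)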