diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLChecker.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLChecker.kt index 63f946721cb1e..5f7b8c2aecb97 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLChecker.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLChecker.kt @@ -11,11 +11,6 @@ import jakarta.inject.Singleton import java.util.UUID import javax.sql.DataSource -const val CHECK_TABLE_STATEMENT = """ - CREATE TABLE ? (test int); - DROP TABLE ?; -""" - @Singleton class MSSQLChecker(private val dataSourceFactory: MSSQLDataSourceFactory) : DestinationChecker { @@ -24,10 +19,13 @@ class MSSQLChecker(private val dataSourceFactory: MSSQLDataSourceFactory) : val testTableName = "check_test_${UUID.randomUUID()}" val fullyQualifiedTableName = "[${config.rawDataSchema}].[${testTableName}]" dataSource.connection.use { connection -> - connection.prepareStatement(CHECK_TABLE_STATEMENT.trimIndent()).use { statement -> - statement.setString(1, fullyQualifiedTableName) - statement.setString(2, fullyQualifiedTableName) - statement.executeUpdate() + connection.createStatement().use { statement -> + statement.executeUpdate( + """ + CREATE TABLE ${fullyQualifiedTableName} (test int); + DROP TABLE ${fullyQualifiedTableName}; + """.trimIndent(), + ) } } } diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLQueryBuilder.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLQueryBuilder.kt index a2ecb7e46a124..e1bad22fa4210 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLQueryBuilder.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLQueryBuilder.kt @@ -4,7 +4,10 @@ package io.airbyte.integrations.destination.mssql.v2 +import io.airbyte.cdk.load.command.Append +import io.airbyte.cdk.load.command.Dedupe import io.airbyte.cdk.load.command.DestinationStream +import io.airbyte.cdk.load.command.Overwrite import io.airbyte.cdk.load.data.AirbyteType import io.airbyte.cdk.load.data.AirbyteValue import io.airbyte.cdk.load.data.FieldType @@ -25,28 +28,84 @@ import io.airbyte.protocol.models.Jsons import io.airbyte.protocol.models.v0.AirbyteRecordMessageMeta import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange import java.lang.ArithmeticException +import java.sql.Connection import java.sql.PreparedStatement import java.sql.ResultSet import java.util.UUID +fun String.executeQuery(connection: Connection, vararg args: String, f: (ResultSet) -> T): T { + connection.prepareStatement(this.trimIndent()).use { statement -> + args.forEachIndexed { index, arg -> statement.setString(index + 1, arg) } + return statement.executeQuery().use(f) + } +} + +fun String.executeUpdate(connection: Connection, vararg args: String) { + connection.prepareStatement(this.trimIndent()).use { statement -> + args.forEachIndexed { index, arg -> statement.setString(index + 1, arg) } + statement.executeUpdate() + } +} + +fun String.toQuery(vararg args: String): String = this.trimIndent().replace("?", "%s").format(*args) + const val GET_EXISTING_SCHEMA_QUERY = """ - SELECT COLUMN_NAME, DATA_TYPE - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? - ORDER BY ORDINAL_POSITION ASC - """ + SELECT COLUMN_NAME, DATA_TYPE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY ORDINAL_POSITION ASC + """ + +const val CREATE_SCHEMA_QUERY = + """ + DECLARE @Schema VARCHAR(MAX) = ? + IF NOT EXISTS (SELECT name FROM sys.schemas WHERE name = @Schema) + BEGIN + EXEC ('CREATE SCHEMA ' + @Schema); + END + """ + +const val ALTER_TABLE_ADD = """ + ALTER TABLE ? + ADD ? ? NULL; + """ +const val ALTER_TABLE_DROP = """ + ALTER TABLE ? + DROP COLUMN ?; + """ +const val ALTER_TABLE_MODIFY = """ + ALTER TABLE ? + ALTER COLUMN ? ? NULL; + """ + +const val DELETE_WHERE_COL_IS_NOT_NULL = + """ + DELETE FROM ? + WHERE ? is not NULL + """ + +const val DELETE_WHERE_COL_LESS_THAN = """ + DELETE FROM ? + WHERE ? < ? + """ + +const val SELECT_FROM = """ + SELECT * + FROM ? + """ class MSSQLQueryBuilder( config: MSSQLConfiguration, private val stream: DestinationStream, ) { - companion object { + const val AIRBYTE_RAW_ID = "_airbyte_raw_id" const val AIRBYTE_EXTRACTED_AT = "_airbyte_extracted_at" const val AIRBYTE_META = "_airbyte_meta" const val AIRBYTE_GENERATION_ID = "_airbyte_generation_id" + const val AIRBYTE_CDC_DELETED_AT = "_ab_cdc_deleted_at" const val DEFAULT_SEPARATOR = ",\n " val airbyteFinalTableFields = @@ -79,19 +138,29 @@ class MSSQLQueryBuilder( private val outputSchema: String = stream.descriptor.namespace ?: config.schema private val tableName: String = stream.descriptor.name - private val fqTableName = "$outputSchema.$tableName" + val fqTableName = "$outputSchema.$tableName" + private val uniquenessKey: List = + when (stream.importType) { + is Dedupe -> + if ((stream.importType as Dedupe).primaryKey.isNotEmpty()) { + (stream.importType as Dedupe).primaryKey.map { it.joinToString(".") } + } else { + listOf((stream.importType as Dedupe).cursor.joinToString(".")) + } + Append -> emptyList() + Overwrite -> emptyList() + } private val toSqlType = AirbyteTypeToSqlType() private val toMssqlType = SqlTypeToMssqlType() val finalTableSchema: List = airbyteFinalTableFields + extractFinalTableSchema(stream.schema) + val hasCdc: Boolean = finalTableSchema.any { it.name == AIRBYTE_CDC_DELETED_AT } - fun getExistingSchema(statement: PreparedStatement): List { + private fun getExistingSchema(connection: Connection): List { val fields = mutableListOf() - statement.setString(1, outputSchema) - statement.setString(2, tableName) - statement.executeQuery().use { rs -> + GET_EXISTING_SCHEMA_QUERY.executeQuery(connection, outputSchema, tableName) { rs -> while (rs.next()) { val name = rs.getString("COLUMN_NAME") val type = MssqlType.valueOf(rs.getString("DATA_TYPE").uppercase()) @@ -101,53 +170,69 @@ class MSSQLQueryBuilder( return fields } - fun getSchema(): List = + private fun getSchema(): List = finalTableSchema.map { NamedSqlField(it.name, toMssqlType.convert(toSqlType.convert(it.type.type))) } - fun alterTableIfNeeded( - existingSchema: List, - expectedSchema: List, - ): String? { + fun updateSchema(connection: Connection) { + val existingSchema = getExistingSchema(connection) + val expectedSchema = getSchema() + val existingFields = existingSchema.associate { it.name to it.type } val expectedFields = expectedSchema.associate { it.name to it.type } if (existingFields == expectedFields) { - return null + return } + val toDelete = existingFields.filter { it.key !in expectedFields } val toAdd = expectedFields.filter { it.key !in existingFields } val toAlter = expectedFields.filter { it.key in existingFields && it.value != existingFields[it.key] } - return StringBuilder() - .apply { - toDelete.entries.forEach { - appendLine("ALTER TABLE $fqTableName") - appendLine("DROP COLUMN ${it.key};") - } - toAdd.entries.forEach { - appendLine("ALTER TABLE $fqTableName") - appendLine("ADD ${it.key} ${it.value.sqlString} NULL;") - } - toAlter.entries.forEach { - appendLine("ALTER TABLE $fqTableName") - appendLine("ALTER COLUMN ${it.key} ${it.value.sqlString} NULL;") + + val query = + StringBuilder() + .apply { + toDelete.entries.forEach { + appendLine(ALTER_TABLE_DROP.toQuery(fqTableName, it.key)) + } + toAdd.entries.forEach { + appendLine(ALTER_TABLE_ADD.toQuery(fqTableName, it.key, it.value.sqlString)) + } + toAlter.entries.forEach { + appendLine( + ALTER_TABLE_MODIFY.toQuery(fqTableName, it.key, it.value.sqlString) + ) + } } - } - .toString() + .toString() + + query.executeUpdate(connection) } - fun createFinalTableIfNotExists(): String = - createTableIfNotExists(fqTableName, finalTableSchema) + fun createTableIfNotExists(connection: Connection) { + CREATE_SCHEMA_QUERY.executeUpdate(connection, outputSchema) - fun createFinalSchemaIfNotExists(): String = createSchemaIfNotExists(outputSchema) + connection.createStatement().use { + it.executeUpdate(createTableIfNotExists(fqTableName, finalTableSchema)) + } + } fun getFinalTableInsertColumnHeader(): String = getFinalTableInsertColumnHeader(fqTableName, finalTableSchema) - fun deletePreviousGenerations(minGenerationId: Long): String = - deleteWhere(fqTableName, minGenerationId) + fun deleteCdc(connection: Connection) = + DELETE_WHERE_COL_IS_NOT_NULL.toQuery(fqTableName, AIRBYTE_CDC_DELETED_AT) + .executeUpdate(connection) + + fun deletePreviousGenerations(connection: Connection, minGenerationId: Long) = + DELETE_WHERE_COL_LESS_THAN.toQuery( + fqTableName, + AIRBYTE_GENERATION_ID, + minGenerationId.toString() + ) + .executeUpdate(connection) fun populateStatement( statement: PreparedStatement, @@ -216,46 +301,63 @@ class MSSQLQueryBuilder( return ObjectValue.from(valueMap) } - fun selectAllRecords(): String = "SELECT * FROM $fqTableName" - - private fun createSchemaIfNotExists(schema: String): String = - """ - IF NOT EXISTS (SELECT name FROM sys.schemas WHERE name = '$schema') - BEGIN - EXEC ('CREATE SCHEMA $schema'); - END - """.trimIndent() + private fun createTableIfNotExists(fqTableName: String, schema: List): String { + val index = + if (uniquenessKey.isNotEmpty()) + createIndex(fqTableName, uniquenessKey, clustered = false) + else "" + val cdcIndex = if (hasCdc) createIndex(fqTableName, listOf(AIRBYTE_CDC_DELETED_AT)) else "" - private fun createTableIfNotExists(fqTableName: String, schema: List): String = - """ + return """ IF OBJECT_ID('$fqTableName') IS NULL BEGIN CREATE TABLE $fqTableName ( ${airbyteTypeToSqlSchema(schema)} ); + $index; + $cdcIndex; END """.trimIndent() + } - private fun deleteWhere(fqTableName: String, minGenerationId: Long) = - """ - DELETE FROM $fqTableName - WHERE $AIRBYTE_GENERATION_ID < $minGenerationId - """.trimIndent() + private fun createIndex( + fqTableName: String, + columns: List, + clustered: Boolean = false + ): String { + val name = "${fqTableName.replace('.', '_')}_${columns.hashCode()}" + val indexType = if (clustered) "CLUSTERED" else "" + return "CREATE $indexType INDEX $name ON $fqTableName (${columns.joinToString(", ")})" + } private fun getFinalTableInsertColumnHeader( fqTableName: String, schema: List ): String { - return StringBuilder() - .apply { - append("INSERT INTO $fqTableName(") - append(schema.map { it.name }.joinToString(", ")) - append(") VALUES (") - append(schema.map { "?" }.joinToString(", ")) - append(")") - } - .toString() + val columns = schema.joinToString(", ") { it.name } + val templateColumns = schema.joinToString(", ") { "?" } + return if (uniquenessKey.isEmpty()) { + """ + INSERT INTO $fqTableName ($columns) + SELECT table_value.* + FROM (VALUES ($templateColumns)) table_value($columns) + """ + } else { + val uniquenessConstraint = + uniquenessKey.joinToString(" AND ") { "Target.$it = Source.$it" } + val updateStatement = schema.joinToString(", ") { "${it.name} = Source.${it.name}" } + """ + MERGE INTO $fqTableName AS Target + USING (VALUES ($templateColumns)) AS Source ($columns) + ON $uniquenessConstraint + WHEN MATCHED THEN + UPDATE SET $updateStatement + WHEN NOT MATCHED BY TARGET THEN + INSERT ($columns) VALUES ($columns) + ; + """.trimIndent() + } } private fun extractFinalTableSchema(schema: AirbyteType): List = diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLStreamLoader.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLStreamLoader.kt index de6474c56af39..9e46466c2b826 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLStreamLoader.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLStreamLoader.kt @@ -45,6 +45,9 @@ class MSSQLStreamLoader( statement.addBatch() } statement.executeLargeBatch() + if (sqlBuilder.hasCdc) { + sqlBuilder.deleteCdc(connection) + } connection.commit() } return SimpleBatch(Batch.State.COMPLETE) @@ -52,29 +55,9 @@ class MSSQLStreamLoader( private fun ensureTableExists(dataSource: DataSource) { try { - // TODO leverage preparedStatement instead of createStatement dataSource.connection.use { connection -> - connection.createStatement().use { statement -> - statement.executeUpdate(sqlBuilder.createFinalSchemaIfNotExists()) - } - connection.createStatement().use { statement -> - statement.executeUpdate(sqlBuilder.createFinalTableIfNotExists()) - } - val alterStatement = - connection.prepareStatement(GET_EXISTING_SCHEMA_QUERY.trimIndent()).use { - statement -> - val existingSchema = sqlBuilder.getExistingSchema(statement) - val expectedSchema = sqlBuilder.getSchema() - sqlBuilder.alterTableIfNeeded( - existingSchema = existingSchema, - expectedSchema = expectedSchema, - ) - } - alterStatement?.let { - connection.createStatement().use { statement -> - statement.executeUpdate(alterStatement) - } - } + sqlBuilder.createTableIfNotExists(connection) + sqlBuilder.updateSchema(connection) } } catch (ex: Exception) { log.error(ex) { ex.message } @@ -85,11 +68,7 @@ class MSSQLStreamLoader( private fun truncatePreviousGenerations(dataSource: DataSource) { // TODO this can be improved to avoid attempting to truncate the data for each sync dataSource.connection.use { connection -> - connection.createStatement().use { statement -> - statement.executeUpdate( - sqlBuilder.deletePreviousGenerations(stream.minimumGenerationId) - ) - } + sqlBuilder.deletePreviousGenerations(connection, stream.minimumGenerationId) } } } diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/convert/AirbyteValueToStatement.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/convert/AirbyteValueToStatement.kt index 1eddb065486a0..30a0ddc57ac7b 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/convert/AirbyteValueToStatement.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/main/kotlin/io/airbyte/integrations/destination/mssql/v2/convert/AirbyteValueToStatement.kt @@ -35,6 +35,12 @@ class AirbyteValueToStatement { companion object { private val toSqlType = AirbyteTypeToSqlType() private val toSqlValue = AirbyteValueToSqlValue() + private val valueCoercingMapper = + AirbyteValueDeepCoercingMapper( + recurseIntoObjects = false, + recurseIntoArrays = false, + recurseIntoUnions = false, + ) fun PreparedStatement.setValue( idx: Int, @@ -111,7 +117,7 @@ class AirbyteValueToStatement { Types.TIMESTAMP_WITH_TIMEZONE ) ) { - val coercedValue = AirbyteValueDeepCoercingMapper().map(value, type) + val coercedValue = valueCoercingMapper.map(value, type) if (coercedValue.second.isEmpty()) { when (coercedValue.first) { is DateValue -> setAsDateValue(idx, coercedValue.first as DateValue) diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLContainerHelper.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLContainerHelper.kt index a8cf2f3320ee1..b37c96f071fa3 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLContainerHelper.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLContainerHelper.kt @@ -7,7 +7,6 @@ package io.airbyte.integrations.destination.mssql.v2 import io.airbyte.cdk.load.test.util.ConfigurationUpdater import io.airbyte.integrations.destination.mssql.v2.MSSQLContainerHelper.getIpAddress import io.airbyte.integrations.destination.mssql.v2.MSSQLContainerHelper.getPort -import io.airbyte.integrations.destination.mssql.v2.MSSQLContainerHelper.testContainer import io.github.oshai.kotlinlogging.KotlinLogging import org.testcontainers.containers.MSSQLServerContainer import org.testcontainers.containers.MSSQLServerContainer.MS_SQL_SERVER_PORT @@ -41,7 +40,7 @@ object MSSQLContainerHelper { fun getPassword(): String = testContainer.password - fun getPort(): Int? = testContainer.firstMappedPort + fun getPort(): Int? = testContainer.getMappedPort(MS_SQL_SERVER_PORT) fun getIpAddress(): String? { // Ensure that the container is started first @@ -50,18 +49,22 @@ object MSSQLContainerHelper { } } -class MSSQLConfigUpdater(private val replacePort: Boolean = false) : ConfigurationUpdater { +class MSSQLConfigUpdater : ConfigurationUpdater { override fun update(config: String): String { var updatedConfig = config + + // If not running the connector in docker, we must use the mapped port to connect to the + // database. Otherwise, get the container's IP address for the host updatedConfig = - MSSQLContainerHelper.getIpAddress()?.let { config.replace("localhost", it) } - ?: updatedConfig - if (replacePort) { - updatedConfig = - getPort()?.let { config.replace("$MS_SQL_SERVER_PORT", it.toString()) } + if (System.getenv("AIRBYTE_CONNECTOR_INTEGRATION_TEST_RUNNER") != "docker") { + getPort()?.let { updatedConfig.replace("$MS_SQL_SERVER_PORT", it.toString()) } ?: updatedConfig - } + } else { + getIpAddress()?.let { config.replace("localhost", it) } ?: updatedConfig + } - return updatedConfig.replace("replace_me", MSSQLContainerHelper.getPassword()) + updatedConfig = updatedConfig.replace("replace_me", MSSQLContainerHelper.getPassword()) + logger.debug { "Using updated MSSQL configuration: $updatedConfig" } + return updatedConfig } } diff --git a/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLWriterTest.kt b/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLWriterTest.kt index e6cf196bec708..b97dbb6b49af6 100644 --- a/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLWriterTest.kt +++ b/airbyte-integrations/connectors/destination-mssql-v2/src/test-integration/kotlin/io/airbyte/integrations/destination/mssql/v2/MSSQLWriterTest.kt @@ -11,7 +11,9 @@ import io.airbyte.cdk.load.test.util.DestinationCleaner import io.airbyte.cdk.load.test.util.DestinationDataDumper import io.airbyte.cdk.load.test.util.OutputRecord import io.airbyte.cdk.load.write.BasicFunctionalityIntegrationTest +import io.airbyte.cdk.load.write.SchematizedNestedValueBehavior import io.airbyte.cdk.load.write.StronglyTyped +import io.airbyte.cdk.load.write.UnionBehavior import io.airbyte.integrations.destination.mssql.v2.config.DataSourceFactory import io.airbyte.integrations.destination.mssql.v2.config.MSSQLConfiguration import io.airbyte.integrations.destination.mssql.v2.config.MSSQLConfigurationFactory @@ -35,16 +37,19 @@ abstract class MSSQLWriterTest( dataDumper = dataDumper, destinationCleaner = dataCleaner, isStreamSchemaRetroactive = true, - supportsDedup = false, + supportsDedup = true, stringifySchemalessObjects = false, - promoteUnionToObject = true, preserveUndeclaredFields = false, commitDataIncrementally = true, allTypesBehavior = StronglyTyped(integerCanBeLarge = false), nullEqualsUnset = true, supportFileTransfer = false, envVars = emptyMap(), - configUpdater = MSSQLConfigUpdater() + configUpdater = MSSQLConfigUpdater(), + schematizedArrayBehavior = SchematizedNestedValueBehavior.STRONGLY_TYPE, + schematizedObjectBehavior = SchematizedNestedValueBehavior.PASS_THROUGH, + unionBehavior = UnionBehavior.PROMOTE_TO_OBJECT, + nullUnknownTypes = false, ) class MSSQLDataDumper : DestinationDataDumper { @@ -57,8 +62,7 @@ class MSSQLDataDumper : DestinationDataDumper { val dataSource = DataSourceFactory().dataSource(config) val output = mutableListOf() dataSource.connection.use { connection -> - val statement = connection.prepareStatement(sqlBuilder.selectAllRecords()) - statement.executeQuery().use { rs -> + SELECT_FROM.toQuery(sqlBuilder.fqTableName).executeQuery(connection) { rs -> while (rs.next()) { val objectValue = sqlBuilder.readResult(rs, sqlBuilder.finalTableSchema) val record =