Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Destination MSSQL] add dedup #51612

Merged
merged 15 commits into from
Jan 21, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ import jakarta.inject.Singleton
import java.util.UUID
import javax.sql.DataSource

const val CHECK_TABLE_STATEMENT = """
CREATE TABLE ? (test int);
DROP TABLE ?;
"""

@Singleton
class MSSQLChecker(private val dataSourceFactory: MSSQLDataSourceFactory) :
DestinationChecker<MSSQLConfiguration> {
Expand All @@ -24,10 +19,13 @@ class MSSQLChecker(private val dataSourceFactory: MSSQLDataSourceFactory) :
val testTableName = "check_test_${UUID.randomUUID()}"
val fullyQualifiedTableName = "[${config.rawDataSchema}].[${testTableName}]"
dataSource.connection.use { connection ->
connection.prepareStatement(CHECK_TABLE_STATEMENT.trimIndent()).use { statement ->
statement.setString(1, fullyQualifiedTableName)
statement.setString(2, fullyQualifiedTableName)
statement.executeUpdate()
connection.createStatement().use { statement ->
statement.executeUpdate(
"""
CREATE TABLE ${fullyQualifiedTableName} (test int);
DROP TABLE ${fullyQualifiedTableName};
""".trimIndent(),
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

package io.airbyte.integrations.destination.mssql.v2

import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.Dedupe
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.Overwrite
import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.FieldType
Expand All @@ -25,28 +28,84 @@ import io.airbyte.protocol.models.Jsons
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMeta
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import java.lang.ArithmeticException
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.util.UUID

/**
 * Runs this string as a parameterized query on [connection] and hands the result set to [f].
 *
 * The receiver is passed through [trimIndent] before being prepared. Each entry of [args] is
 * bound as a string parameter, in order. Both the prepared statement and the result set are
 * closed once [f] has returned.
 *
 * @param connection open JDBC connection used to prepare the statement
 * @param args values bound to the `?` markers, first value to the first marker
 * @param f consumer of the result set; its return value becomes the result of this call
 * @return whatever [f] returns
 */
fun <T> String.executeQuery(connection: Connection, vararg args: String, f: (ResultSet) -> T): T =
    connection.prepareStatement(trimIndent()).use { prepared ->
        for ((position, value) in args.withIndex()) {
            // JDBC parameter positions start at 1.
            prepared.setString(position + 1, value)
        }
        prepared.executeQuery().use { rows -> f(rows) }
    }

/**
 * Runs this string as a parameterized update statement on [connection].
 *
 * The receiver is passed through [trimIndent] before being prepared. Each entry of [args] is
 * bound as a string parameter, in order. The prepared statement is closed afterwards and the
 * update count is discarded.
 *
 * @param connection open JDBC connection used to prepare the statement
 * @param args values bound to the `?` markers, first value to the first marker
 */
fun String.executeUpdate(connection: Connection, vararg args: String) {
    val sql = trimIndent()
    connection.prepareStatement(sql).use { prepared ->
        // JDBC parameter positions start at 1.
        var position = 1
        for (value in args) {
            prepared.setString(position, value)
            position += 1
        }
        prepared.executeUpdate()
    }
}

/**
 * Renders this template into SQL text by substituting each `?` with the matching entry of
 * [args], in order. The template is passed through [trimIndent] first.
 *
 * This is plain text substitution, not JDBC parameter binding, so it is the caller's job to pass
 * only trusted identifiers and values.
 *
 * Literal `%` characters in the template (for example in a `LIKE` pattern) are preserved: they
 * are escaped before the template is handed to [String.format], which would otherwise try to
 * read them as format specifiers and throw.
 *
 * @param args replacement text for each `?`, first value to the first marker
 * @return the rendered SQL string
 */
fun String.toQuery(vararg args: String): String =
    this.trimIndent()
        // Escape first so a pre-existing '%' is never mistaken for a conversion.
        .replace("%", "%%")
        .replace("?", "%s")
        .format(*args)

const val GET_EXISTING_SCHEMA_QUERY =
"""
SELECT COLUMN_NAME, DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION ASC
"""
SELECT COLUMN_NAME, DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION ASC
"""

/**
 * Creates the schema named by the single `?` parameter when `sys.schemas` has no entry for it.
 *
 * `CREATE SCHEMA` is issued through dynamic SQL. The name is wrapped with `QUOTENAME` before it
 * is spliced into that statement, so a name containing spaces, punctuation or SQL fragments is
 * treated as one delimited identifier instead of being executed as-is. The statement is built
 * in a variable first because `EXEC (...)` accepts only string literals and variables, not
 * function calls.
 */
const val CREATE_SCHEMA_QUERY =
    """
        DECLARE @Schema VARCHAR(MAX) = ?
        IF NOT EXISTS (SELECT name FROM sys.schemas WHERE name = @Schema)
        BEGIN
            DECLARE @CreateSchema NVARCHAR(MAX) = N'CREATE SCHEMA ' + QUOTENAME(@Schema);
            EXEC (@CreateSchema);
        END
    """

/**
 * Template that adds a nullable column. Placeholders, in order: table name, column name,
 * column SQL type. Filled textually with [toQuery]; identifiers are not bound as JDBC
 * parameters.
 */
const val ALTER_TABLE_ADD = """
ALTER TABLE ?
ADD ? ? NULL;
"""
/**
 * Template that drops a column. Placeholders, in order: table name, column name. Filled
 * textually with [toQuery].
 */
const val ALTER_TABLE_DROP = """
ALTER TABLE ?
DROP COLUMN ?;
"""
/**
 * Template that changes a column's type, leaving it nullable. Placeholders, in order: table
 * name, column name, new column SQL type. Filled textually with [toQuery].
 */
const val ALTER_TABLE_MODIFY = """
ALTER TABLE ?
ALTER COLUMN ? ? NULL;
"""

/**
 * Template that deletes every row whose given column is non-null. Placeholders, in order:
 * table name, column name. Filled textually with [toQuery].
 */
const val DELETE_WHERE_COL_IS_NOT_NULL =
"""
DELETE FROM ?
WHERE ? is not NULL
"""

/**
 * Template that deletes every row whose given column is below a threshold. Placeholders, in
 * order: table name, column name, threshold value. Filled textually with [toQuery].
 */
const val DELETE_WHERE_COL_LESS_THAN = """
DELETE FROM ?
WHERE ? < ?
"""

/**
 * Template that selects all columns of a table. Single placeholder: the table name.
 * NOTE(review): no caller is visible in this part of the file; presumably filled with
 * [toQuery] like the other templates -- confirm.
 */
const val SELECT_FROM = """
SELECT *
FROM ?
"""

class MSSQLQueryBuilder(
config: MSSQLConfiguration,
private val stream: DestinationStream,
) {

companion object {

const val AIRBYTE_RAW_ID = "_airbyte_raw_id"
const val AIRBYTE_EXTRACTED_AT = "_airbyte_extracted_at"
const val AIRBYTE_META = "_airbyte_meta"
const val AIRBYTE_GENERATION_ID = "_airbyte_generation_id"
const val AIRBYTE_CDC_DELETED_AT = "_ab_cdc_deleted_at"
const val DEFAULT_SEPARATOR = ",\n "

val airbyteFinalTableFields =
Expand Down Expand Up @@ -79,19 +138,29 @@ class MSSQLQueryBuilder(

private val outputSchema: String = stream.descriptor.namespace ?: config.schema
private val tableName: String = stream.descriptor.name
private val fqTableName = "$outputSchema.$tableName"
val fqTableName = "$outputSchema.$tableName"
/**
 * Column names that identify an existing row of this stream.
 *
 * For a [Dedupe] stream this is the primary key, each key path joined with `.` into a single
 * name; when no primary key is declared the cursor path (joined the same way) is used instead.
 * [Append] and [Overwrite] streams have no uniqueness key, so the list is empty.
 *
 * A non-empty list switches the insert statement to a `MERGE` and adds an index on these
 * columns when the table is created.
 */
private val uniquenessKey: List<String> =
    // Bind the import type once so the `is Dedupe` branch smart-casts it; no repeated `as` casts.
    when (val importType = stream.importType) {
        is Dedupe ->
            if (importType.primaryKey.isNotEmpty()) {
                importType.primaryKey.map { it.joinToString(".") }
            } else {
                // NOTE(review): with an empty cursor this yields a single blank key (""),
                // which is still treated as non-empty downstream -- confirm upstream validation.
                listOf(importType.cursor.joinToString("."))
            }
        Append -> emptyList()
        Overwrite -> emptyList()
    }

// Converters chained by getSchema(): Airbyte type -> SQL type -> MSSQL column type.
private val toSqlType = AirbyteTypeToSqlType()
private val toMssqlType = SqlTypeToMssqlType()

// The fixed airbyteFinalTableFields followed by the fields extracted from the stream schema.
val finalTableSchema: List<NamedField> =
airbyteFinalTableFields + extractFinalTableSchema(stream.schema)
// True when the final table schema contains a field named AIRBYTE_CDC_DELETED_AT.
val hasCdc: Boolean = finalTableSchema.any { it.name == AIRBYTE_CDC_DELETED_AT }

fun getExistingSchema(statement: PreparedStatement): List<NamedSqlField> {
private fun getExistingSchema(connection: Connection): List<NamedSqlField> {
val fields = mutableListOf<NamedSqlField>()
statement.setString(1, outputSchema)
statement.setString(2, tableName)
statement.executeQuery().use { rs ->
GET_EXISTING_SCHEMA_QUERY.executeQuery(connection, outputSchema, tableName) { rs ->
while (rs.next()) {
val name = rs.getString("COLUMN_NAME")
val type = MssqlType.valueOf(rs.getString("DATA_TYPE").uppercase())
Expand All @@ -101,53 +170,69 @@ class MSSQLQueryBuilder(
return fields
}

fun getSchema(): List<NamedSqlField> =
/**
 * Builds the expected column list for the final table: one [NamedSqlField] per entry of
 * [finalTableSchema], keeping the field name and converting its Airbyte type to an MSSQL type
 * by way of the intermediate SQL type.
 */
private fun getSchema(): List<NamedSqlField> {
    return finalTableSchema.map { field ->
        val sqlType = toSqlType.convert(field.type.type)
        NamedSqlField(field.name, toMssqlType.convert(sqlType))
    }
}

fun alterTableIfNeeded(
existingSchema: List<NamedSqlField>,
expectedSchema: List<NamedSqlField>,
): String? {
fun updateSchema(connection: Connection) {
val existingSchema = getExistingSchema(connection)
val expectedSchema = getSchema()

val existingFields = existingSchema.associate { it.name to it.type }
val expectedFields = expectedSchema.associate { it.name to it.type }

if (existingFields == expectedFields) {
return null
return
}

val toDelete = existingFields.filter { it.key !in expectedFields }
val toAdd = expectedFields.filter { it.key !in existingFields }
val toAlter =
expectedFields.filter { it.key in existingFields && it.value != existingFields[it.key] }
return StringBuilder()
.apply {
toDelete.entries.forEach {
appendLine("ALTER TABLE $fqTableName")
appendLine("DROP COLUMN ${it.key};")
}
toAdd.entries.forEach {
appendLine("ALTER TABLE $fqTableName")
appendLine("ADD ${it.key} ${it.value.sqlString} NULL;")
}
toAlter.entries.forEach {
appendLine("ALTER TABLE $fqTableName")
appendLine("ALTER COLUMN ${it.key} ${it.value.sqlString} NULL;")

val query =
StringBuilder()
.apply {
toDelete.entries.forEach {
appendLine(ALTER_TABLE_DROP.toQuery(fqTableName, it.key))
}
toAdd.entries.forEach {
appendLine(ALTER_TABLE_ADD.toQuery(fqTableName, it.key, it.value.sqlString))
}
toAlter.entries.forEach {
appendLine(
ALTER_TABLE_MODIFY.toQuery(fqTableName, it.key, it.value.sqlString)
)
}
}
}
.toString()
.toString()

query.executeUpdate(connection)
}

fun createFinalTableIfNotExists(): String =
createTableIfNotExists(fqTableName, finalTableSchema)
fun createTableIfNotExists(connection: Connection) {
CREATE_SCHEMA_QUERY.executeUpdate(connection, outputSchema)

fun createFinalSchemaIfNotExists(): String = createSchemaIfNotExists(outputSchema)
connection.createStatement().use {
it.executeUpdate(createTableIfNotExists(fqTableName, finalTableSchema))
}
}

/**
 * Returns the statement text used to write one record into the final table, by delegating to
 * the private overload with this builder's table name and final table schema.
 */
fun getFinalTableInsertColumnHeader(): String {
    return getFinalTableInsertColumnHeader(fqTableName, finalTableSchema)
}

fun deletePreviousGenerations(minGenerationId: Long): String =
deleteWhere(fqTableName, minGenerationId)
/**
 * Deletes every row of this stream's table whose [AIRBYTE_CDC_DELETED_AT] column is non-null.
 *
 * @param connection open JDBC connection the delete is executed on
 */
fun deleteCdc(connection: Connection) {
    val sql = DELETE_WHERE_COL_IS_NOT_NULL.toQuery(fqTableName, AIRBYTE_CDC_DELETED_AT)
    sql.executeUpdate(connection)
}

/**
 * Deletes every row of this stream's table whose [AIRBYTE_GENERATION_ID] is lower than
 * [minGenerationId].
 *
 * @param connection open JDBC connection the delete is executed on
 * @param minGenerationId rows with a generation id strictly below this value are removed
 */
fun deletePreviousGenerations(connection: Connection, minGenerationId: Long) {
    val threshold = minGenerationId.toString()
    val sql = DELETE_WHERE_COL_LESS_THAN.toQuery(fqTableName, AIRBYTE_GENERATION_ID, threshold)
    sql.executeUpdate(connection)
}

fun populateStatement(
statement: PreparedStatement,
Expand Down Expand Up @@ -216,46 +301,63 @@ class MSSQLQueryBuilder(
return ObjectValue.from(valueMap)
}

fun selectAllRecords(): String = "SELECT * FROM $fqTableName"

private fun createSchemaIfNotExists(schema: String): String =
"""
IF NOT EXISTS (SELECT name FROM sys.schemas WHERE name = '$schema')
BEGIN
EXEC ('CREATE SCHEMA $schema');
END
""".trimIndent()
private fun createTableIfNotExists(fqTableName: String, schema: List<NamedField>): String {
val index =
if (uniquenessKey.isNotEmpty())
createIndex(fqTableName, uniquenessKey, clustered = false)
else ""
val cdcIndex = if (hasCdc) createIndex(fqTableName, listOf(AIRBYTE_CDC_DELETED_AT)) else ""

private fun createTableIfNotExists(fqTableName: String, schema: List<NamedField>): String =
"""
return """
IF OBJECT_ID('$fqTableName') IS NULL
BEGIN
CREATE TABLE $fqTableName
(
${airbyteTypeToSqlSchema(schema)}
);
$index;
$cdcIndex;
END
""".trimIndent()
}

private fun deleteWhere(fqTableName: String, minGenerationId: Long) =
"""
DELETE FROM $fqTableName
WHERE $AIRBYTE_GENERATION_ID < $minGenerationId
""".trimIndent()
/**
 * Builds a `CREATE INDEX` statement over [columns] of [fqTableName].
 *
 * The index name is the table name with `.` replaced by `_`, followed by `_` and a hash of the
 * column list. The hash is rendered as an unsigned number: `hashCode()` can be negative, and a
 * leading `-` inside the unquoted index name would make the statement invalid.
 *
 * @param fqTableName table the index is created on
 * @param columns indexed columns, emitted comma-separated in the given order
 * @param clustered when true the `CLUSTERED` keyword is added
 * @return the statement text, without a trailing semicolon
 */
private fun createIndex(
    fqTableName: String,
    columns: List<String>,
    clustered: Boolean = false
): String {
    // Non-negative hashes render exactly as before; only the '-' of negative ones is avoided.
    val columnsHash = Integer.toUnsignedString(columns.hashCode())
    val name = "${fqTableName.replace('.', '_')}_$columnsHash"
    val indexType = if (clustered) "CLUSTERED" else ""
    return "CREATE $indexType INDEX $name ON $fqTableName (${columns.joinToString(", ")})"
}

private fun getFinalTableInsertColumnHeader(
fqTableName: String,
schema: List<NamedField>
): String {
return StringBuilder()
.apply {
append("INSERT INTO $fqTableName(")
append(schema.map { it.name }.joinToString(", "))
append(") VALUES (")
append(schema.map { "?" }.joinToString(", "))
append(")")
}
.toString()
val columns = schema.joinToString(", ") { it.name }
val templateColumns = schema.joinToString(", ") { "?" }
return if (uniquenessKey.isEmpty()) {
"""
INSERT INTO $fqTableName ($columns)
SELECT table_value.*
FROM (VALUES ($templateColumns)) table_value($columns)
"""
} else {
val uniquenessConstraint =
uniquenessKey.joinToString(" AND ") { "Target.$it = Source.$it" }
val updateStatement = schema.joinToString(", ") { "${it.name} = Source.${it.name}" }
"""
MERGE INTO $fqTableName AS Target
USING (VALUES ($templateColumns)) AS Source ($columns)
ON $uniquenessConstraint
WHEN MATCHED THEN
UPDATE SET $updateStatement
WHEN NOT MATCHED BY TARGET THEN
INSERT ($columns) VALUES ($columns)
;
""".trimIndent()
}
}

private fun extractFinalTableSchema(schema: AirbyteType): List<NamedField> =
Expand Down
Loading
Loading