Recovering from a lost db connection. #11
thake committed Dec 26, 2020
1 parent 7741899 commit 171571e
Showing 5 changed files with 118 additions and 48 deletions.
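Note: in brief, this commit replaces the lazily cached Connection in SourceConnectorConfig with an openConnection() factory plus an on-demand connection getter in StartedState that validates the cached connection (isClosed / isValid) and reopens it when the database has gone away; poll paths that cannot reach the database return an empty batch instead of failing the task. Below is a minimal, self-contained sketch of that reconnect pattern — the class and parameter names (ReconnectingConnectionHolder, jdbcUrl, maxAttempts) are placeholders for illustration, not the connector's actual API:

```kotlin
import java.sql.Connection
import java.sql.DriverManager
import java.sql.SQLException

// Minimal sketch of the reconnect-on-demand pattern introduced in this commit.
// ReconnectingConnectionHolder, jdbcUrl and maxAttempts are hypothetical names,
// not the connector's actual classes or configuration keys.
class ReconnectingConnectionHolder(
    private val jdbcUrl: String,
    private val user: String,
    private val password: String,
    private val maxAttempts: Int = 3
) {
    private var current: Connection? = null

    // Open a fresh connection, retrying a bounded number of times before giving up.
    fun openConnection(): Connection {
        var lastError: SQLException? = null
        repeat(maxAttempts) {
            try {
                return DriverManager.getConnection(jdbcUrl, user, password)
            } catch (e: SQLException) {
                lastError = e
            }
        }
        throw SQLException("Couldn't connect to $jdbcUrl after $maxAttempts attempts", lastError)
    }

    // Hand out a usable connection, replacing one that is closed or fails a
    // short liveness check (JDBC isValid takes a timeout in seconds).
    val connection: Connection
        get() {
            val existing = current
            if (existing != null && !existing.isClosed && existing.isValid(1)) {
                return existing
            }
            return openConnection().also { current = it }
        }
}
```

Callers fetch the connection from such a holder on every use instead of caching it themselves, so a dropped database surfaces as a fresh connection attempt rather than a stale, unusable handle.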
@@ -28,9 +28,9 @@ class SourceConnectorConfig(
parsedConfig
)

val connection by lazy {
fun openConnection() : Connection? {
val dbUri = "${dbHostName}:${dbPort}/${dbSid}"
fun openConnection(): Connection {
fun doOpenConnection(): Connection {
return DriverManager.getConnection(
"jdbc:oracle:thin:@$dbUri",
dbUser, dbPassword
@@ -48,13 +48,13 @@ class SourceConnectorConfig(
}
currentAttempt++
try {
connection = openConnection()
connection = doOpenConnection()
} catch (e: SQLException) {
logger.error(e) { "Couldn't connect to database with url $dbUri. Attempt $currentAttempt." }

}
}
connection ?: throw SQLException("Couldn't initialize Connection to $dbUri after $dbAttempts attempts.")
return connection
}


@@ -14,12 +14,23 @@ import java.sql.SQLException
import java.util.*

private val logger = KotlinLogging.logger {}

class NoConnectionToDatabase : RuntimeException()
sealed class TaskState
object StoppedState : TaskState()
data class StartedState(val config: SourceConnectorConfig, val context: SourceTaskContext) : TaskState() {
val connection: Connection by lazy {
config.connection
private var currentConnection: Connection? = null

val connection : Connection?
get() {
var connection = currentConnection
if(connection != null && (connection.isClosed || !connection.isValid(1000))){
connection = null
}
if(connection == null){
connection = config.openConnection()
currentConnection = connection
}
return connection
}
var offset: Offset?
val nameService: ConnectNameService = SourceDatabaseNameService(config.dbName)
@@ -31,8 +42,9 @@ data class StartedState(val config: SourceConnectorConfig, val context: SourceTa
private val connectSchemaFactory = ConnectSchemaFactory(nameService, isEmittingTombstones = config.isTombstonesOnDelete)

init {
val workingConnection = connection ?: error("No connection to database possible at startup time. Aborting.")
fun getTablesForOwner(owner: String): List<TableId> {
return connection.metaData.getTables(null, owner, null, arrayOf("TABLE")).use {
return workingConnection.metaData.getTables(null, owner, null, arrayOf("TABLE")).use {
val result = mutableListOf<TableId>()
while (it.next()) {
result.add(TableId(owner, it.getString(3)))
@@ -101,7 +113,8 @@ data class StartedState(val config: SourceConnectorConfig, val context: SourceTa
fun poll(): List<SourceRecord> {
logger.debug { "Polling database for new changes ..." }
fun doPoll(): List<PollResult> {
source.maybeStartQuery(connection)
val workingConnection = connection ?: throw NoConnectionToDatabase()
source.maybeStartQuery(workingConnection)
val result = source.poll()
//Advance the offset and source
offset = source.getOffset()
@@ -139,7 +152,7 @@ data class StartedState(val config: SourceConnectorConfig, val context: SourceTa
fun stop() {
logger.info { "Kafka connect oracle task will be stopped" }
this.source.close()
this.connection.close()
this.connection?.close()
}

}
@@ -172,12 +185,15 @@ class SourceTask : SourceTask() {

@Throws(InterruptedException::class)
override fun poll(): List<SourceRecord> {
try {
return try {
val currState = state
return if (currState is StartedState) currState.poll() else throw IllegalStateException("Task has not been started")
if (currState is StartedState) currState.poll() else throw IllegalStateException("Task has not been started")
} catch (e: SQLException) {
logger.info(e) { "SQLException thrown. This is most probably due to an error while stopping." }
return Collections.emptyList()
Collections.emptyList()
} catch (e : NoConnectionToDatabase){
logger.info(e) {"Currently no connection to the database can be established. Returning 0 records to kafka."}
Collections.emptyList()
}
}
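Design note (not part of the diff): a Kafka Connect source task that returns an empty list from poll() is simply polled again later, so catching the new NoConnectionToDatabase here and returning no records keeps the task alive and turns a database outage into a transparent retry. A rough sketch of that control flow, with placeholder types standing in for the connector's own classes:

```kotlin
import org.apache.kafka.connect.source.SourceRecord

// Placeholder stand-ins for the connector's own types, used only for illustration.
class NoConnectionToDatabase : RuntimeException()
interface StartedPoller { fun poll(): List<SourceRecord> }

// Sketch of the poll() error handling added in this commit: a lost connection
// degrades to an empty batch instead of failing the task, so Kafka Connect
// simply calls poll() again on the next cycle.
fun pollOnce(poller: StartedPoller?): List<SourceRecord> = try {
    poller?.poll() ?: throw IllegalStateException("Task has not been started")
} catch (e: NoConnectionToDatabase) {
    // No database right now; report zero records and let the framework retry.
    emptyList()
}
```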

@@ -3,6 +3,7 @@ package com.github.thake.logminer.kafka.connect.initial
import com.github.thake.logminer.kafka.connect.*
import mu.KotlinLogging
import java.sql.Connection
import java.sql.SQLException

private val logger = KotlinLogging.logger {}
class SelectSource(
@@ -24,10 +25,17 @@ class SelectSource(
override fun getOffset() = lastOffset

override fun maybeStartQuery(db: Connection) {
val tableFetcher = currentTableFetcher
if(tableFetcher != null && tableFetcher.conn != db){
tableFetcher.close()
currentTableFetcher = null
}
if (currentTableFetcher == null) {
val offset = FetcherOffset(determineTableToFetch(), determineAsOfScn(db), lastOffset?.rowId)
logger.debug { "Starting new table fetcher with offset $offset" }
currentTableFetcher = TableFetcher(
db,
FetcherOffset(determineTableToFetch(), determineAsOfScn(db), lastOffset?.rowId),
offset,
schemaService = schemaService
)
}
@@ -52,36 +60,42 @@
}

override fun poll(): List<PollResult> {
var fetcher = currentTableFetcher ?: throw IllegalStateException("maybeStartQuery hasn't been called")
val result = mutableListOf<PollResult>()
while (result.size < batchSize && continuePolling) {
val nextRecord = fetcher.poll()
if (nextRecord != null) {
lastOffset = nextRecord.offset as SelectOffset
result.add(nextRecord)
} else {
//No new records from the current table. Close the fetcher and check the next table
fetcher.close()
val newIndex = tablesToFetch.indexOf(fetcher.fetcherOffset.table) + 1
if (newIndex < tablesToFetch.size) {
fetcher = TableFetcher(
fetcher.conn,
FetcherOffset(tablesToFetch[newIndex], fetcher.fetcherOffset.asOfScn, null),
schemaService
)
currentTableFetcher = fetcher
//Exit the loop to return the current result set if it is not empty.
if (result.isNotEmpty()) {
break
}
try{
var fetcher = currentTableFetcher ?: throw IllegalStateException("maybeStartQuery hasn't been called")
val result = mutableListOf<PollResult>()
while (result.size < batchSize && continuePolling) {
val nextRecord = fetcher.poll()
if (nextRecord != null) {
lastOffset = nextRecord.offset as SelectOffset
result.add(nextRecord)
} else {
//no more records to poll all tables polled
logger.debug { "Stopping fetching from tables as fetch from table ${fetcher.fetcherOffset.table} did not provide any more results." }
continuePolling = false
//No new records from the current table. Close the fetcher and check the next table
fetcher.close()
val newIndex = tablesToFetch.indexOf(fetcher.fetcherOffset.table) + 1
if (newIndex < tablesToFetch.size) {
fetcher = TableFetcher(
fetcher.conn,
FetcherOffset(tablesToFetch[newIndex], fetcher.fetcherOffset.asOfScn, null),
schemaService
)
currentTableFetcher = fetcher
//Exit the loop to return the current result set if it is not empty.
if (result.isNotEmpty()) {
break
}
} else {
//no more records to poll all tables polled
logger.debug { "Stopping fetching from tables as fetch from table ${fetcher.fetcherOffset.table} did not provide any more results." }
continuePolling = false
}
}
}
return result
}catch (e : SQLException){
currentTableFetcher = null
continuePolling = true
throw e
}
return result
}

override fun close() {
@@ -17,7 +17,6 @@ class TableFetcher(val conn: Connection, val fetcherOffset: FetcherOffset, val s
private val schemaDefinition: SchemaDefinition

init {

fun determineQuery(): String {
val rowIdCondition = fetcherOffset.rowId?.let { "WHERE ROWID > '$it'" } ?: ""
return "SELECT t.*, ROWID, ORA_ROWSCN FROM ${fetcherOffset.table.fullName} AS OF SCN ${fetcherOffset.asOfScn} t $rowIdCondition order by ROWID ASC"
@@ -1,5 +1,7 @@
package com.github.thake.logminer.kafka.connect

import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.collections.shouldHaveSize
import org.apache.kafka.connect.source.SourceRecord
import org.apache.kafka.connect.source.SourceTaskContext
import org.apache.kafka.connect.storage.OffsetStorageReader
@@ -8,6 +10,7 @@ import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.LoggerFactory
import org.testcontainers.junit.jupiter.Testcontainers
import java.sql.Connection
import java.util.*
@@ -17,6 +20,7 @@ class SourceTaskTest : AbstractIntegrationTest() {
private lateinit var sourceTask: SourceTask
private lateinit var offsetManager: MockOffsetStorageReader
private lateinit var defaultConfig: Map<String, String>
private val log = LoggerFactory.getLogger(SourceTaskTest::class.java)

private class TestSourceTaskContext(
val configs: Map<String, String>,
@@ -101,7 +105,7 @@ class SourceTaskTest : AbstractIntegrationTest() {
//Now add new rows
(100 until 200).forEach { modifyingConnection.insertRow(it) }
//Now continue reading until poll returns an empty list
result.addAll(readAllSourceRecords(sourceTask))
result.addAll(sourceTask.readAllSourceRecords())
assertEquals(200, result.size)
}

@@ -136,7 +140,7 @@ class SourceTaskTest : AbstractIntegrationTest() {
//Now add new rows
(100 until 200).forEach { modifyingConnection.insertRow(it) }
//Now continue reading until poll returns an empty list
result.addAll(readAllSourceRecords(sourceTask))
result.addAll(sourceTask.readAllSourceRecords())
assertEquals(100, result.size)
result.forEach { record ->
assertEquals(CDC_TYPE, record.sourceOffset()["type"])
@@ -160,12 +164,12 @@
}
)
)
val result = readAllSourceRecords(sourceTask).toMutableList()
val result = sourceTask.readAllSourceRecords().toMutableList()
assertEquals(100, result.size, "Result does not contain the same size as the number of inserted entries.")
//Now add new rows
(100 until 200).forEach { modifyingConnection.insertRow(it) }
//Now continue reading until poll returns an empty list
result.addAll(readAllSourceRecords(sourceTask))
result.addAll(sourceTask.readAllSourceRecords())
assertEquals(200, result.size)
result.forEach { record ->
assertEquals(CDC_TYPE, record.sourceOffset()["type"])
@@ -202,14 +206,51 @@ class SourceTaskTest : AbstractIntegrationTest() {
//Now add new rows
(100 until 200).forEach { modifyingConnection.insertRow(it) }
//Now continue reading until poll returns an empty list
result.addAll(readAllSourceRecords(sourceTask))
result.addAll(sourceTask.readAllSourceRecords())
assertEquals(200, result.size)
}
@Test
fun testResumeDuringCDCAfterDbConnectionLost() {
sourceTask.start(
createConfiguration(
mapOf(
SourceConnectorConfig.BATCH_SIZE to "10"
)
)
)
val modifyingConnection = openConnection()
//Initial state
(0 until 10).forEach { modifyingConnection.insertRow(it, SECOND_TABLE) }
val result = sourceTask.poll().toMutableList()

//Check that the batch size is correct
result.shouldHaveSize(10)

//Now add new rows
(100 until 200).forEach { modifyingConnection.insertRow(it) }
//Fetch the next 10 rows. These should be the first cdc rows
result.addAll(sourceTask.poll())
result.shouldHaveSize(20)

log.info("Stopping oracle DB to simulate a lost connection")
val stopResult = oracle.execInContainer("/bin/bash","-c","service oracle-xe stop")
log.info("Stop exited with code ${stopResult.exitCode} and log output: ${stopResult.stdout} Err: ${stopResult.stderr}")
//try to poll now. Should return in an empty result
val expectedEmptyResult = sourceTask.poll()
expectedEmptyResult.shouldBeEmpty()
//Starting again
val startResult = oracle.execInContainer("/bin/bash", "-c", "service oracle-xe start")
log.info("Start exited with code ${startResult.exitCode} and log output: ${startResult.stdout} Err: ${startResult.stderr}")

//Now continue reading until poll returns an empty list
result.addAll(sourceTask.readAllSourceRecords())
assertEquals(110, result.size)
}

private fun readAllSourceRecords(sourceTask: SourceTask): List<SourceRecord> {
private fun SourceTask.readAllSourceRecords(): List<SourceRecord> {
val result = mutableListOf<SourceRecord>()
while (true) {
val currentResult = sourceTask.poll()
val currentResult = poll()
if (currentResult.isEmpty()) {
break
} else {
