Skip to content

Commit

Permalink
Made get_conn in JdbcHook threadsafe to avoid OSError: JVM is already…
Browse files Browse the repository at this point in the history
… started (apache#44718)

* refactor: Made get_conn in JdbcHook thread-safe to avoid "OSError: JVM is already started" when used in a multithreaded environment

* Refactor: removed commented code

* refactor: Reorganized imports

* refactor: Added white line

* refactor: Fixed static checks test JdbcHook

* refactor: Added white line

* refactor: Refactored the JdbcHook get_conn method to use an RLock, as suggested by Jarek, instead of the wrapt synchronized decorator

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Dec 13, 2024
1 parent 999aad3 commit 2c01457
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
17 changes: 10 additions & 7 deletions providers/src/airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import traceback
import warnings
from contextlib import contextmanager
from threading import RLock
from typing import TYPE_CHECKING, Any

import jaydebeapi
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
super().__init__(*args, **kwargs)
self._driver_path = driver_path
self._driver_class = driver_class
self.lock = RLock()

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -183,13 +185,14 @@ def get_conn(self) -> jaydebeapi.Connection:
login: str = conn.login
psw: str = conn.password

conn = jaydebeapi.connect(
jclassname=self.driver_class,
url=str(host),
driver_args=[str(login), str(psw)],
jars=self.driver_path.split(",") if self.driver_path else None,
)
return conn
with self.lock:
conn = jaydebeapi.connect(
jclassname=self.driver_class,
url=str(host),
driver_args=[str(login), str(psw)],
jars=self.driver_path.split(",") if self.driver_path else None,
)
return conn

def set_autocommit(self, conn: jaydebeapi.Connection, autocommit: bool) -> None:
"""
Expand Down
42 changes: 41 additions & 1 deletion providers/tests/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import json
import logging
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import current_thread
from time import sleep
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import jaydebeapi
import pytest
Expand All @@ -35,6 +38,7 @@


jdbc_conn_mock = Mock(name="jdbc_conn")
logger = logging.getLogger(__name__)


def get_hook(
Expand Down Expand Up @@ -229,3 +233,39 @@ def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
jdbc_hook.get_conn = lambda: connection
engine = jdbc_hook.get_sqlalchemy_engine()
assert engine.connect().connection.connection == connection

def test_get_conn_thread_safety(self):
    """
    Check that concurrent ``get_conn`` calls never overlap inside ``jaydebeapi.connect``.

    ``jaydebeapi.connect`` is patched with a side effect that counts how many
    calls are in progress at once and raises ``OSError`` as soon as a second
    call enters while the first is still running. Ten worker threads then call
    ``get_conn`` on a single hook; every call must return the mocked connection
    and ``connect`` must have been invoked exactly ten times.
    """
    expected_conn = MagicMock()
    in_flight = 0

    def connect_side_effect(*args, **kwargs):
        # Stand-in for jaydebeapi.connect that fails when entered concurrently.
        nonlocal in_flight
        in_flight += 1
        logger.debug("Thread %s has %s open connections", current_thread().name, in_flight)

        try:
            if in_flight > 1:
                raise OSError("JVM is already started")
        finally:
            # Hold the "connection" open briefly so overlapping callers would be
            # observed, then release it again.
            sleep(0.1)
            in_flight -= 1

        return expected_conn

    with patch.object(jaydebeapi, "connect", side_effect=connect_side_effect) as mock_connect:
        jdbc_hook = get_hook()

        def call_get_conn():
            assert jdbc_hook.get_conn() is expected_conn

        with ThreadPoolExecutor(max_workers=10) as pool:
            pending = [pool.submit(call_get_conn) for _ in range(10)]

            # result() re-raises whatever the worker raised, so an OSError from
            # the side effect (or a failed assertion) surfaces here.
            for finished in as_completed(pending):
                finished.result()

        assert mock_connect.call_count == 10

0 comments on commit 2c01457

Please sign in to comment.