From 2c0145766a12a55a6a5096a7f78a15654b0ea129 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 13 Dec 2024 17:53:50 +0100 Subject: [PATCH] Made get_conn in JdbcHook threadsafe to avoid OSError: JVM is already started (#44718) * refactor: Made get_conn in JdbcHook threadsafe to avoid OSError: JVM is already started when used in multithreaded environment * Refactor: removed commented code * refactor: Reorganized imports * refactor: Added white line * refactor: Fixed static checks test JdbcHook * refactor: Added white line * refactor: Refactored JdbcHook get_conn method using RLock as suggested by Jarek instead of wrapt synchronized decorator --------- Co-authored-by: David Blain --- .../src/airflow/providers/jdbc/hooks/jdbc.py | 17 ++++---- providers/tests/jdbc/hooks/test_jdbc.py | 42 ++++++++++++++++++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/providers/src/airflow/providers/jdbc/hooks/jdbc.py b/providers/src/airflow/providers/jdbc/hooks/jdbc.py index 47fcbe8e039c9..808b946bd9762 100644 --- a/providers/src/airflow/providers/jdbc/hooks/jdbc.py +++ b/providers/src/airflow/providers/jdbc/hooks/jdbc.py @@ -20,6 +20,7 @@ import traceback import warnings from contextlib import contextmanager +from threading import RLock from typing import TYPE_CHECKING, Any import jaydebeapi @@ -98,6 +99,7 @@ def __init__( super().__init__(*args, **kwargs) self._driver_path = driver_path self._driver_class = driver_class + self.lock = RLock() @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: @@ -183,13 +185,14 @@ def get_conn(self) -> jaydebeapi.Connection: login: str = conn.login psw: str = conn.password - conn = jaydebeapi.connect( - jclassname=self.driver_class, - url=str(host), - driver_args=[str(login), str(psw)], - jars=self.driver_path.split(",") if self.driver_path else None, - ) - return conn + with self.lock: + conn = jaydebeapi.connect( + jclassname=self.driver_class, + url=str(host), + driver_args=[str(login), str(psw)], + jars=self.driver_path.split(",") if self.driver_path else None, + ) + return conn def set_autocommit(self, conn: jaydebeapi.Connection, autocommit: bool) -> None: """ diff --git a/providers/tests/jdbc/hooks/test_jdbc.py b/providers/tests/jdbc/hooks/test_jdbc.py index cfb27934d86da..73015b5b522ab 100644 --- a/providers/tests/jdbc/hooks/test_jdbc.py +++ b/providers/tests/jdbc/hooks/test_jdbc.py @@ -20,8 +20,11 @@ import json import logging import sqlite3 +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import current_thread +from time import sleep from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import jaydebeapi import pytest @@ -35,6 +38,7 @@ jdbc_conn_mock = Mock(name="jdbc_conn") +logger = logging.getLogger(__name__) def get_hook( @@ -229,3 +233,39 @@ def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): jdbc_hook.get_conn = lambda: connection engine = jdbc_hook.get_sqlalchemy_engine() assert engine.connect().connection.connection == connection + + def test_get_conn_thread_safety(self): + mock_conn = MagicMock() + open_connections = 0 + + def connect_side_effect(*args, **kwargs): + nonlocal open_connections + open_connections += 1 + logger.debug("Thread %s has %s open connections", current_thread().name, open_connections) + + try: + if open_connections > 1: + raise OSError("JVM is already started") + finally: + sleep(0.1) # wait a bit before releasing the connection again + open_connections -= 1 + + return mock_conn + + with patch.object(jaydebeapi, "connect", side_effect=connect_side_effect) as mock_connect: + jdbc_hook = get_hook() + + def call_get_conn(): + conn = jdbc_hook.get_conn() + assert conn is mock_conn + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + + for _ in range(0, 10): + futures.append(executor.submit(call_get_conn)) + + for future in as_completed(futures): + future.result() # This will raise OSError if get_conn isn't threadsafe + + assert mock_connect.call_count == 10