diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index ba595751320b3..21a26104ddeb2 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -23,14 +23,13 @@ from sqlalchemy import select -from airflow.api_internal.internal_api_call import internal_api_call from airflow.secrets import BaseSecretsBackend from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.connection import Connection + from airflow.models import Connection class MetastoreBackend(BaseSecretsBackend): @@ -38,33 +37,29 @@ class MetastoreBackend(BaseSecretsBackend): @provide_session def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None: - return MetastoreBackend._fetch_connection(conn_id, session=session) - - @provide_session - def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: """ - Get Airflow Variable from Metadata DB. + Get Airflow Connection from Metadata DB. - :param key: Variable Key - :return: Variable Value + :param conn_id: Connection ID + :param session: SQLAlchemy Session + :return: Connection Object """ - return MetastoreBackend._fetch_variable(key=key, session=session) - - @staticmethod - @internal_api_call - @provide_session - def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connection | None: - from airflow.models.connection import Connection + from airflow.models import Connection conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id).limit(1)) session.expunge_all() return conn - @staticmethod - @internal_api_call @provide_session - def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None: - from airflow.models.variable import Variable + def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: + """ + Get Airflow Variable from Metadata DB. + + :param key: Variable Key + :param session: SQLAlchemy Session + :return: Variable Value + """ + from airflow.models import Variable var_value = session.scalar(select(Variable).where(Variable.key == key).limit(1)) session.expunge_all()