From ca2e737e3a4c3b3af6280e16e38ca2ccd1ea4163 Mon Sep 17 00:00:00 2001 From: disheng Date: Fri, 3 Jan 2025 23:19:10 +0800 Subject: [PATCH] Added backend support for MySQL SSL feature --- ibis-server/app/model/__init__.py | 2 ++ ibis-server/app/model/data_source.py | 40 +++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 98e3a081f..066348b13 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -95,6 +95,8 @@ class MySqlConnectionInfo(BaseModel): database: SecretStr user: SecretStr password: SecretStr + ssl_mode: SecretStr = Field(alias="sslMode") + ssl_ca: SecretStr | None = Field(alias="sslCA", default=None) kwargs: dict[str, str] | None = Field( description="Additional keyword arguments to pass to PyMySQL", default=None ) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index ba282e973..dc5f4f9c8 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -1,8 +1,10 @@ from __future__ import annotations import base64 +import ssl from enum import Enum, StrEnum, auto from json import loads +from typing import Optional import ibis from google.oauth2 import service_account @@ -30,6 +32,12 @@ ) +class SSLMode(str, Enum): + DISABLE = "Disable" + REQUIRE = "Require" + VERIFY_CA = "Verify CA" + + class DataSource(StrEnum): bigquery = auto() canner = auto() @@ -123,15 +131,19 @@ def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend: **info.kwargs if info.kwargs else dict(), ) - @staticmethod - def get_mysql_connection(info: MySqlConnectionInfo) -> BaseBackend: + @classmethod + def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend: + ssl_context = cls._create_ssl_context(info) + kwargs = {"ssl": ssl_context} if ssl_context else {} + if info.kwargs: + kwargs.update(info.kwargs) return ibis.mysql.connect( host=info.host.get_secret_value(), port=int(info.port.get_secret_value()), database=info.database.get_secret_value(), user=info.user.get_secret_value(), password=info.password.get_secret_value(), - **info.kwargs if info.kwargs else dict(), + **kwargs, ) @staticmethod @@ -168,3 +180,25 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: @staticmethod def _escape_special_characters_for_odbc(value: str) -> str: return "{" + value.replace("}", "}}") + "}" + + @staticmethod + def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]: + ssl_mode = info.ssl_mode.get_secret_value() + + if ssl_mode == SSLMode.DISABLE: + return None + + ctx = ssl.create_default_context() + ctx.check_hostname = False + + if ssl_mode == SSLMode.REQUIRE: + ctx.verify_mode = ssl.CERT_NONE + elif ssl_mode == SSLMode.VERIFY_CA: + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations( + cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8") + if info.ssl_ca + else None + ) + + return ctx