Skip to content

Commit

Permalink
Added backend support for MySQL SSL feature
Browse files Browse the repository at this point in the history
  • Loading branch information
ongdisheng committed Jan 3, 2025
1 parent fbf56c8 commit ca2e737
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class MySqlConnectionInfo(BaseModel):
database: SecretStr
user: SecretStr
password: SecretStr
ssl_mode: SecretStr = Field(alias="sslMode")
ssl_ca: SecretStr | None = Field(alias="sslCA", default=None)
kwargs: dict[str, str] | None = Field(
description="Additional keyword arguments to pass to PyMySQL", default=None
)
Expand Down
40 changes: 37 additions & 3 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import base64
import ssl
from enum import Enum, StrEnum, auto
from json import loads
from typing import Optional

import ibis
from google.oauth2 import service_account
Expand Down Expand Up @@ -30,6 +32,12 @@
)


class SSLMode(str, Enum):
DISABLE = "Disable"
REQUIRE = "Require"
VERIFY_CA = "Verify CA"


class DataSource(StrEnum):
bigquery = auto()
canner = auto()
Expand Down Expand Up @@ -123,15 +131,19 @@ def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend:
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
def get_mysql_connection(info: MySqlConnectionInfo) -> BaseBackend:
@classmethod
def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend:
ssl_context = cls._create_ssl_context(info)
kwargs = {"ssl": ssl_context} if ssl_context else {}
if info.kwargs:
kwargs.update(info.kwargs)
return ibis.mysql.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
**kwargs,
)

@staticmethod
Expand Down Expand Up @@ -168,3 +180,25 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
@staticmethod
def _escape_special_characters_for_odbc(value: str) -> str:
return "{" + value.replace("}", "}}") + "}"

@staticmethod
def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]:
ssl_mode = info.ssl_mode.get_secret_value()

if ssl_mode == SSLMode.DISABLE:
return None

ctx = ssl.create_default_context()
ctx.check_hostname = False

if ssl_mode == SSLMode.REQUIRE:
ctx.verify_mode = ssl.CERT_NONE
elif ssl_mode == SSLMode.VERIFY_CA:
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(
cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
if info.ssl_ca
else None
)

return ctx

0 comments on commit ca2e737

Please sign in to comment.