Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mssql): use pyodbc instead for improved reliability over pymssql #648

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion modules/mssql/testcontainers/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class SqlServerContainer(DbContainer):
... engine = sqlalchemy.create_engine(mssql.get_connection_url())
... with engine.begin() as connection:
... result = connection.execute(sqlalchemy.text("select @@VERSION"))

Notes
-----
Requires `ODBC Driver 17 for SQL Server <https://docs.microsoft.com/en-us/sql/connect/odbc/
linux-mac/installing-the-microsoft-odbc-driver-for-sql-server>`_.
"""

def __init__(
Expand All @@ -31,6 +36,7 @@ def __init__(
port: int = 1433,
dbname: str = "tempdb",
dialect: str = "mssql+pymssql",
driver: str = "ODBC Driver 17 for SQL Server",
**kwargs,
) -> None:
raise_for_deprecated_parameter(kwargs, "user", "username")
Expand All @@ -43,6 +49,7 @@ def __init__(
self.username = username
self.dbname = dbname
self.dialect = dialect
self.driver = driver

def _configure(self) -> None:
self.with_env("SA_PASSWORD", self.password)
Expand All @@ -56,6 +63,8 @@ def _connect(self) -> None:
assert status == 0, "Cannot run 'SELECT 1': container is not ready"

def get_connection_url(self) -> str:
return super()._create_connection_url(
base_url = super()._create_connection_url(
dialect=self.dialect, username=self.username, password=self.password, dbname=self.dbname, port=self.port
)
url = base_url + f"?driver={'+'.join(self.driver.split(' '))}"
return url
55 changes: 55 additions & 0 deletions mssql/testcontainers/mssql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from os import environ
from typing import Optional
from testcontainers.core.generic import DbContainer
from testcontainers.core.utils import raise_for_deprecated_parameter


class SqlServerContainer(DbContainer):
"""
Microsoft SQL Server database container.

Example:

.. doctest::

>>> import sqlalchemy
>>> from testcontainers.mssql import SqlServerContainer

>>> with SqlServerContainer() as mssql:
... engine = sqlalchemy.create_engine(mssql.get_connection_url())
... result = engine.execute(sqlalchemy.text("select @@VERSION"))
Notes
-----
Requires `ODBC Driver 17 for SQL Server <https://docs.microsoft.com/en-us/sql/connect/odbc/
linux-mac/installing-the-microsoft-odbc-driver-for-sql-server>`_.
"""

def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest",
username: str = "SA", password: Optional[str] = None, port: int = 1433,
dbname: str = "tempdb", dialect: str = 'mssql+pymssql', driver: str = "ODBC Driver 17 for SQL Server", **kwargs) -> None:
raise_for_deprecated_parameter(kwargs, "user", "username")
super(SqlServerContainer, self).__init__(image, **kwargs)

self.port = port
self.with_exposed_ports(self.port)

self.password = password or environ.get("SQLSERVER_PASSWORD", "1Secure*Password1")
self.username = username
self.dbname = dbname
self.dialect = dialect
self.driver = driver

def _configure(self) -> None:
self.with_env("SA_PASSWORD", self.password)
self.with_env("SQLSERVER_USER", self.username)
self.with_env("SQLSERVER_DBNAME", self.dbname)
self.with_env("ACCEPT_EULA", 'Y')

def get_connection_url(self) -> str:
base_url = super(SqlServerContainer, self)._create_connection_url(
dialect=self.dialect, username=self.username, password=self.password,
db_name=self.dbname, port=self.port
)
url = base_url + f"?driver={'+'.join(self.driver.split(' '))}"
return url