From 9b1797fe4d9b3060010b4482d4d099a1857200af Mon Sep 17 00:00:00 2001 From: MarkBNinetyOne Date: Mon, 15 Jul 2024 17:49:14 +0200 Subject: [PATCH] use pyodbc instead --- mssql/testcontainers/mssql/__init__.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py index 9de6edf0..e28db35f 100644 --- a/mssql/testcontainers/mssql/__init__.py +++ b/mssql/testcontainers/mssql/__init__.py @@ -15,15 +15,18 @@ class SqlServerContainer(DbContainer): >>> import sqlalchemy >>> from testcontainers.mssql import SqlServerContainer - >>> with SqlServerContainer() as mssql: - ... engine = sqlalchemy.create_engine(mssql.get_connection_url()) - ... with engine.begin() as connection: - ... result = connection.execute(sqlalchemy.text("select @@VERSION")) + >>> with SqlServerContainer() as mssql: + ... engine = sqlalchemy.create_engine(mssql.get_connection_url()) + ... result = engine.execute(sqlalchemy.text("select @@VERSION")) + Notes + ----- + Requires `ODBC Driver 17 for SQL Server `_. """ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", username: str = "SA", password: Optional[str] = None, port: int = 1433, - dbname: str = "tempdb", dialect: str = 'mssql+pymssql', **kwargs) -> None: + dbname: str = "tempdb", dialect: str = 'mssql+pymssql', driver: str = "ODBC Driver 17 for SQL Server", **kwargs) -> None: raise_for_deprecated_parameter(kwargs, "user", "username") super(SqlServerContainer, self).__init__(image, **kwargs) @@ -34,6 +37,7 @@ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", self.username = username self.dbname = dbname self.dialect = dialect + self.driver = driver def _configure(self) -> None: self.with_env("SA_PASSWORD", self.password) @@ -42,7 +46,10 @@ def _configure(self) -> None: self.with_env("ACCEPT_EULA", 'Y') def get_connection_url(self) -> str: - return super()._create_connection_url( + base_url = super(SqlServerContainer, self)._create_connection_url( dialect=self.dialect, username=self.username, password=self.password, - dbname=self.dbname, port=self.port + db_name=self.dbname, port=self.port ) + url = base_url + f"?driver={'+'.join(self.driver.split(' '))}" + return url +