diff --git a/modules/mssql/testcontainers/mssql/__init__.py b/modules/mssql/testcontainers/mssql/__init__.py index 6cee3681..cd966dd5 100644 --- a/modules/mssql/testcontainers/mssql/__init__.py +++ b/modules/mssql/testcontainers/mssql/__init__.py @@ -21,6 +21,11 @@ class SqlServerContainer(DbContainer): ... engine = sqlalchemy.create_engine(mssql.get_connection_url()) ... with engine.begin() as connection: ... result = connection.execute(sqlalchemy.text("select @@VERSION")) + + Notes + ----- + Requires `ODBC Driver 17 for SQL Server `_. """ def __init__( @@ -31,6 +36,7 @@ def __init__( port: int = 1433, dbname: str = "tempdb", dialect: str = "mssql+pymssql", + driver: str = "ODBC Driver 17 for SQL Server", **kwargs, ) -> None: raise_for_deprecated_parameter(kwargs, "user", "username") @@ -43,6 +49,7 @@ def __init__( self.username = username self.dbname = dbname self.dialect = dialect + self.driver = driver def _configure(self) -> None: self.with_env("SA_PASSWORD", self.password) @@ -56,6 +63,8 @@ def _connect(self) -> None: assert status == 0, "Cannot run 'SELECT 1': container is not ready" def get_connection_url(self) -> str: - return super()._create_connection_url( + base_url = super()._create_connection_url( dialect=self.dialect, username=self.username, password=self.password, dbname=self.dbname, port=self.port ) + url = base_url + f"?driver={'+'.join(self.driver.split(' '))}" + return url diff --git a/mssql/testcontainers/mssql/__init__.py b/mssql/testcontainers/mssql/__init__.py new file mode 100644 index 00000000..e28db35f --- /dev/null +++ b/mssql/testcontainers/mssql/__init__.py @@ -0,0 +1,55 @@ +from os import environ +from typing import Optional +from testcontainers.core.generic import DbContainer +from testcontainers.core.utils import raise_for_deprecated_parameter + + +class SqlServerContainer(DbContainer): + """ + Microsoft SQL Server database container. + + Example: + + .. doctest:: + + >>> import sqlalchemy + >>> from testcontainers.mssql import SqlServerContainer + + >>> with SqlServerContainer() as mssql: + ... engine = sqlalchemy.create_engine(mssql.get_connection_url()) + ... result = engine.execute(sqlalchemy.text("select @@VERSION")) + Notes + ----- + Requires `ODBC Driver 17 for SQL Server `_. + """ + + def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest", + username: str = "SA", password: Optional[str] = None, port: int = 1433, + dbname: str = "tempdb", dialect: str = 'mssql+pymssql', driver: str = "ODBC Driver 17 for SQL Server", **kwargs) -> None: + raise_for_deprecated_parameter(kwargs, "user", "username") + super(SqlServerContainer, self).__init__(image, **kwargs) + + self.port = port + self.with_exposed_ports(self.port) + + self.password = password or environ.get("SQLSERVER_PASSWORD", "1Secure*Password1") + self.username = username + self.dbname = dbname + self.dialect = dialect + self.driver = driver + + def _configure(self) -> None: + self.with_env("SA_PASSWORD", self.password) + self.with_env("SQLSERVER_USER", self.username) + self.with_env("SQLSERVER_DBNAME", self.dbname) + self.with_env("ACCEPT_EULA", 'Y') + + def get_connection_url(self) -> str: + base_url = super(SqlServerContainer, self)._create_connection_url( + dialect=self.dialect, username=self.username, password=self.password, + db_name=self.dbname, port=self.port + ) + url = base_url + f"?driver={'+'.join(self.driver.split(' '))}" + return url +