diff --git a/modules/mysql/testcontainers/mysql/__init__.py b/modules/mysql/testcontainers/mysql/__init__.py index c4449593..4c381ff9 100644 --- a/modules/mysql/testcontainers/mysql/__init__.py +++ b/modules/mysql/testcontainers/mysql/__init__.py @@ -31,14 +31,14 @@ class MySqlContainer(DbContainer): The example will spin up a MySql database to which you can connect with the credentials passed in the constructor. Alternatively, you may use the :code:`get_connection_url()` method which returns a sqlalchemy-compatible url in format - :code:`dialect+driver://username:password@host:port/database`. + :code:`mysql+dialect://username:password@host:port/database`. .. doctest:: >>> import sqlalchemy >>> from testcontainers.mysql import MySqlContainer - >>> with MySqlContainer('mysql:5.7.17') as mysql: + >>> with MySqlContainer("mysql:5.7.17", dialect="pymysql") as mysql: ... engine = sqlalchemy.create_engine(mysql.get_connection_url()) ... with engine.begin() as connection: ... result = connection.execute(sqlalchemy.text("select version()")) @@ -64,6 +64,7 @@ class MySqlContainer(DbContainer): def __init__( self, image: str = "mysql:latest", + dialect: Optional[str] = None, username: Optional[str] = None, root_password: Optional[str] = None, password: Optional[str] = None, @@ -72,6 +73,10 @@ def __init__( seed: Optional[str] = None, **kwargs, ) -> None: + if dialect is not None and dialect.startswith("mysql+"): + msg = "Please remove 'mysql+' prefix from dialect parameter" + raise ValueError(msg) + raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username") raise_for_deprecated_parameter(kwargs, "MYSQL_ROOT_PASSWORD", "root_password") raise_for_deprecated_parameter(kwargs, "MYSQL_PASSWORD", "password") @@ -85,6 +90,9 @@ def __init__( self.password = password or environ.get("MYSQL_PASSWORD", "test") self.dbname = dbname or environ.get("MYSQL_DATABASE", "test") + self.dialect = dialect or environ.get("MYSQL_DIALECT", None) + self._db_url_dialect_part = "mysql" if self.dialect is None else f"mysql+{self.dialect}" + if self.username == "root": self.root_password = self.password self.seed = seed @@ -105,7 +113,11 @@ def _connect(self) -> None: def get_connection_url(self) -> str: return super()._create_connection_url( - dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port + dialect=self._db_url_dialect_part, + username=self.username, + password=self.password, + dbname=self.dbname, + port=self.port, ) def _transfer_seed(self) -> None: diff --git a/modules/mysql/tests/test_mysql.py b/modules/mysql/tests/test_mysql.py index af0b4491..a2d2c2ec 100644 --- a/modules/mysql/tests/test_mysql.py +++ b/modules/mysql/tests/test_mysql.py @@ -11,9 +11,14 @@ @pytest.mark.inside_docker_check def test_docker_run_mysql(): - config = MySqlContainer("mysql:8.3.0") + config = MySqlContainer("mysql:8.3.0", dialect="pymysql") with config as mysql: - engine = sqlalchemy.create_engine(mysql.get_connection_url()) + connection_url = mysql.get_connection_url() + + assert mysql.dialect == "pymysql" + assert connection_url.startswith("mysql+pymysql://") + + engine = sqlalchemy.create_engine(connection_url) with engine.begin() as connection: result = connection.execute(sqlalchemy.text("select version()")) for row in result: @@ -22,7 +27,7 @@ def test_docker_run_mysql(): @pytest.mark.skipif(is_arm(), reason="mysql container not available for ARM") def test_docker_run_legacy_mysql(): - config = MySqlContainer("mysql:5.7.44") + config = MySqlContainer("mysql:5.7.44", dialect="pymysql") with config as mysql: engine = sqlalchemy.create_engine(mysql.get_connection_url()) with engine.begin() as connection: @@ -35,7 +40,7 @@ def test_docker_run_legacy_mysql(): def test_docker_run_mysql_8_seed(): # Avoid pytest CWD path issues SEEDS_PATH = (Path(__file__).parent / "seeds").absolute() - config = MySqlContainer("mysql:8", seed=SEEDS_PATH) + config = MySqlContainer("mysql:8", dialect="pymysql", seed=str(SEEDS_PATH)) with config as mysql: engine = sqlalchemy.create_engine(mysql.get_connection_url()) with engine.begin() as connection: @@ -45,7 +50,7 @@ def test_docker_run_mysql_8_seed(): @pytest.mark.parametrize("version", ["11.3.2", "10.11.7"]) def test_docker_run_mariadb(version: str): - with MySqlContainer(f"mariadb:{version}") as mariadb: + with MySqlContainer(f"mariadb:{version}", dialect="pymysql") as mariadb: engine = sqlalchemy.create_engine(mariadb.get_connection_url()) with engine.begin() as connection: result = connection.execute(sqlalchemy.text("select version()")) @@ -55,7 +60,7 @@ def test_docker_run_mariadb(version: str): def test_docker_env_variables(): with ( - mock.patch.dict("os.environ", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"), + mock.patch.dict("os.environ", MYSQL_DIALECT="pymysql", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"), MySqlContainer("mariadb:10.6.5").with_bind_ports(3306, 32785) as container, ): url = container.get_connection_url() @@ -63,6 +68,21 @@ def test_docker_env_variables(): assert re.match(pattern, url) +@pytest.mark.parametrize( + "dialect", + [ + "mysql+pymysql", + "mysql+mariadb", + "mysql+mysqldb", + ], +) +def test_mysql_dialect_expecting_error_on_mysql_prefix(dialect: str): + match = f"Please remove *.* prefix from dialect parameter" + + with pytest.raises(ValueError, match=match): + _ = MySqlContainer("mariadb:10.6.5", dialect=dialect) + + # This is a feature in the generic DbContainer class # but it can't be tested on its own # so is tested in various database modules: @@ -75,18 +95,18 @@ def test_quoted_password(): user = "root" password = "p@$%25+0&%rd :/!=?" quoted_password = "p%40%24%2525+0%26%25rd %3A%2F%21%3D%3F" - driver = "pymysql" - with MySqlContainer("mariadb:10.6.5", username=user, password=password) as container: + dialect = "pymysql" + with MySqlContainer("mariadb:10.6.5", dialect=dialect, username=user, password=password) as container: host = container.get_container_host_ip() port = container.get_exposed_port(3306) - expected_url = f"mysql+{driver}://{user}:{quoted_password}@{host}:{port}/test" + expected_url = f"mysql+{dialect}://{user}:{quoted_password}@{host}:{port}/test" url = container.get_connection_url() assert url == expected_url with sqlalchemy.create_engine(expected_url).begin() as connection: connection.execute(sqlalchemy.text("select version()")) - raw_pass_url = f"mysql+{driver}://{user}:{password}@{host}:{port}/test" + raw_pass_url = f"mysql+{dialect}://{user}:{password}@{host}:{port}/test" with pytest.raises(Exception): with sqlalchemy.create_engine(raw_pass_url).begin() as connection: connection.execute(sqlalchemy.text("select version()"))