Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mysql): add dialect parameter instead of hardcoded mysql dialect #739

Merged
merged 14 commits into from
Dec 12, 2024
18 changes: 15 additions & 3 deletions modules/mysql/testcontainers/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ class MySqlContainer(DbContainer):
The example will spin up a MySql database to which you can connect with the credentials
passed in the constructor. Alternatively, you may use the :code:`get_connection_url()`
method which returns a sqlalchemy-compatible url in format
:code:`dialect+driver://username:password@host:port/database`.
:code:`mysql+dialect://username:password@host:port/database`.

.. doctest::

>>> import sqlalchemy
>>> from testcontainers.mysql import MySqlContainer

>>> with MySqlContainer('mysql:5.7.17') as mysql:
>>> with MySqlContainer("mysql:5.7.17", dialect="pymysql") as mysql:
... engine = sqlalchemy.create_engine(mysql.get_connection_url())
... with engine.begin() as connection:
... result = connection.execute(sqlalchemy.text("select version()"))
Expand All @@ -64,6 +64,7 @@ class MySqlContainer(DbContainer):
def __init__(
self,
image: str = "mysql:latest",
dialect: Optional[str] = None,
username: Optional[str] = None,
root_password: Optional[str] = None,
password: Optional[str] = None,
Expand All @@ -72,6 +73,10 @@ def __init__(
seed: Optional[str] = None,
**kwargs,
) -> None:
if dialect is not None and dialect.startswith("mysql+"):
msg = "Please remove 'mysql+' prefix from dialect parameter"
raise ValueError(msg)

raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username")
raise_for_deprecated_parameter(kwargs, "MYSQL_ROOT_PASSWORD", "root_password")
raise_for_deprecated_parameter(kwargs, "MYSQL_PASSWORD", "password")
Expand All @@ -85,6 +90,9 @@ def __init__(
self.password = password or environ.get("MYSQL_PASSWORD", "test")
self.dbname = dbname or environ.get("MYSQL_DATABASE", "test")

self.dialect = dialect or environ.get("MYSQL_DIALECT", None)
self._db_url_dialect_part = "mysql" if self.dialect is None else f"mysql+{self.dialect}"

if self.username == "root":
self.root_password = self.password
self.seed = seed
Expand All @@ -105,7 +113,11 @@ def _connect(self) -> None:

def get_connection_url(self) -> str:
return super()._create_connection_url(
dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port
dialect=self._db_url_dialect_part,
username=self.username,
password=self.password,
dbname=self.dbname,
port=self.port,
)

def _transfer_seed(self) -> None:
Expand Down
40 changes: 30 additions & 10 deletions modules/mysql/tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@

@pytest.mark.inside_docker_check
def test_docker_run_mysql():
config = MySqlContainer("mysql:8.3.0")
config = MySqlContainer("mysql:8.3.0", dialect="pymysql")
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
connection_url = mysql.get_connection_url()

assert mysql.dialect == "pymysql"
assert connection_url.startswith("mysql+pymysql://")

engine = sqlalchemy.create_engine(connection_url)
with engine.begin() as connection:
result = connection.execute(sqlalchemy.text("select version()"))
for row in result:
Expand All @@ -22,7 +27,7 @@ def test_docker_run_mysql():

@pytest.mark.skipif(is_arm(), reason="mysql container not available for ARM")
def test_docker_run_legacy_mysql():
config = MySqlContainer("mysql:5.7.44")
config = MySqlContainer("mysql:5.7.44", dialect="pymysql")
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
with engine.begin() as connection:
Expand All @@ -35,7 +40,7 @@ def test_docker_run_legacy_mysql():
def test_docker_run_mysql_8_seed():
# Avoid pytest CWD path issues
SEEDS_PATH = (Path(__file__).parent / "seeds").absolute()
config = MySqlContainer("mysql:8", seed=SEEDS_PATH)
config = MySqlContainer("mysql:8", dialect="pymysql", seed=str(SEEDS_PATH))
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
with engine.begin() as connection:
Expand All @@ -45,7 +50,7 @@ def test_docker_run_mysql_8_seed():

@pytest.mark.parametrize("version", ["11.3.2", "10.11.7"])
def test_docker_run_mariadb(version: str):
with MySqlContainer(f"mariadb:{version}") as mariadb:
with MySqlContainer(f"mariadb:{version}", dialect="pymysql") as mariadb:
engine = sqlalchemy.create_engine(mariadb.get_connection_url())
with engine.begin() as connection:
result = connection.execute(sqlalchemy.text("select version()"))
Expand All @@ -55,14 +60,29 @@ def test_docker_run_mariadb(version: str):

def test_docker_env_variables():
with (
mock.patch.dict("os.environ", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
mock.patch.dict("os.environ", MYSQL_DIALECT="pymysql", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
MySqlContainer("mariadb:10.6.5").with_bind_ports(3306, 32785) as container,
):
url = container.get_connection_url()
pattern = r"mysql\+pymysql:\/\/demo:test@[\w,.]+:(3306|32785)\/custom_db"
assert re.match(pattern, url)


@pytest.mark.parametrize(
"dialect",
[
"mysql+pymysql",
"mysql+mariadb",
"mysql+mysqldb",
],
)
def test_mysql_dialect_expecting_error_on_mysql_prefix(dialect: str):
match = f"Please remove *.* prefix from dialect parameter"

with pytest.raises(ValueError, match=match):
_ = MySqlContainer("mariadb:10.6.5", dialect=dialect)


# This is a feature in the generic DbContainer class
# but it can't be tested on its own
# so is tested in various database modules:
Expand All @@ -75,18 +95,18 @@ def test_quoted_password():
user = "root"
password = "p@$%25+0&%rd :/!=?"
quoted_password = "p%40%24%2525+0%26%25rd %3A%2F%21%3D%3F"
driver = "pymysql"
with MySqlContainer("mariadb:10.6.5", username=user, password=password) as container:
dialect = "pymysql"
with MySqlContainer("mariadb:10.6.5", dialect=dialect, username=user, password=password) as container:
host = container.get_container_host_ip()
port = container.get_exposed_port(3306)
expected_url = f"mysql+{driver}://{user}:{quoted_password}@{host}:{port}/test"
expected_url = f"mysql+{dialect}://{user}:{quoted_password}@{host}:{port}/test"
url = container.get_connection_url()
assert url == expected_url

with sqlalchemy.create_engine(expected_url).begin() as connection:
connection.execute(sqlalchemy.text("select version()"))

raw_pass_url = f"mysql+{driver}://{user}:{password}@{host}:{port}/test"
raw_pass_url = f"mysql+{dialect}://{user}:{password}@{host}:{port}/test"
with pytest.raises(Exception):
with sqlalchemy.create_engine(raw_pass_url).begin() as connection:
connection.execute(sqlalchemy.text("select version()"))