Skip to content

Commit

Permalink
fix(mysql): add dialect parameter instead of hardcoded mysql dialect (#…
Browse files Browse the repository at this point in the history
…739)

closes
#727

* add parameter `dialect`;
* tests fixing and add some assertions
  • Loading branch information
nightblure authored Dec 12, 2024
1 parent 3436cbf commit 8d77bd3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
18 changes: 15 additions & 3 deletions modules/mysql/testcontainers/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ class MySqlContainer(DbContainer):
The example will spin up a MySql database to which you can connect with the credentials
passed in the constructor. Alternatively, you may use the :code:`get_connection_url()`
method which returns a sqlalchemy-compatible url in format
:code:`dialect+driver://username:password@host:port/database`.
:code:`mysql+dialect://username:password@host:port/database`.
.. doctest::
>>> import sqlalchemy
>>> from testcontainers.mysql import MySqlContainer
>>> with MySqlContainer('mysql:5.7.17') as mysql:
>>> with MySqlContainer("mysql:5.7.17", dialect="pymysql") as mysql:
... engine = sqlalchemy.create_engine(mysql.get_connection_url())
... with engine.begin() as connection:
... result = connection.execute(sqlalchemy.text("select version()"))
Expand All @@ -64,6 +64,7 @@ class MySqlContainer(DbContainer):
def __init__(
self,
image: str = "mysql:latest",
dialect: Optional[str] = None,
username: Optional[str] = None,
root_password: Optional[str] = None,
password: Optional[str] = None,
Expand All @@ -72,6 +73,10 @@ def __init__(
seed: Optional[str] = None,
**kwargs,
) -> None:
if dialect is not None and dialect.startswith("mysql+"):
msg = "Please remove 'mysql+' prefix from dialect parameter"
raise ValueError(msg)

raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username")
raise_for_deprecated_parameter(kwargs, "MYSQL_ROOT_PASSWORD", "root_password")
raise_for_deprecated_parameter(kwargs, "MYSQL_PASSWORD", "password")
Expand All @@ -85,6 +90,9 @@ def __init__(
self.password = password or environ.get("MYSQL_PASSWORD", "test")
self.dbname = dbname or environ.get("MYSQL_DATABASE", "test")

self.dialect = dialect or environ.get("MYSQL_DIALECT", None)
self._db_url_dialect_part = "mysql" if self.dialect is None else f"mysql+{self.dialect}"

if self.username == "root":
self.root_password = self.password
self.seed = seed
Expand All @@ -105,7 +113,11 @@ def _connect(self) -> None:

def get_connection_url(self) -> str:
return super()._create_connection_url(
dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port
dialect=self._db_url_dialect_part,
username=self.username,
password=self.password,
dbname=self.dbname,
port=self.port,
)

def _transfer_seed(self) -> None:
Expand Down
40 changes: 30 additions & 10 deletions modules/mysql/tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@

@pytest.mark.inside_docker_check
def test_docker_run_mysql():
config = MySqlContainer("mysql:8.3.0")
config = MySqlContainer("mysql:8.3.0", dialect="pymysql")
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
connection_url = mysql.get_connection_url()

assert mysql.dialect == "pymysql"
assert connection_url.startswith("mysql+pymysql://")

engine = sqlalchemy.create_engine(connection_url)
with engine.begin() as connection:
result = connection.execute(sqlalchemy.text("select version()"))
for row in result:
Expand All @@ -22,7 +27,7 @@ def test_docker_run_mysql():

@pytest.mark.skipif(is_arm(), reason="mysql container not available for ARM")
def test_docker_run_legacy_mysql():
config = MySqlContainer("mysql:5.7.44")
config = MySqlContainer("mysql:5.7.44", dialect="pymysql")
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
with engine.begin() as connection:
Expand All @@ -35,7 +40,7 @@ def test_docker_run_legacy_mysql():
def test_docker_run_mysql_8_seed():
# Avoid pytest CWD path issues
SEEDS_PATH = (Path(__file__).parent / "seeds").absolute()
config = MySqlContainer("mysql:8", seed=SEEDS_PATH)
config = MySqlContainer("mysql:8", dialect="pymysql", seed=str(SEEDS_PATH))
with config as mysql:
engine = sqlalchemy.create_engine(mysql.get_connection_url())
with engine.begin() as connection:
Expand All @@ -45,7 +50,7 @@ def test_docker_run_mysql_8_seed():

@pytest.mark.parametrize("version", ["11.3.2", "10.11.7"])
def test_docker_run_mariadb(version: str):
with MySqlContainer(f"mariadb:{version}") as mariadb:
with MySqlContainer(f"mariadb:{version}", dialect="pymysql") as mariadb:
engine = sqlalchemy.create_engine(mariadb.get_connection_url())
with engine.begin() as connection:
result = connection.execute(sqlalchemy.text("select version()"))
Expand All @@ -55,14 +60,29 @@ def test_docker_run_mariadb(version: str):

def test_docker_env_variables():
with (
mock.patch.dict("os.environ", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
mock.patch.dict("os.environ", MYSQL_DIALECT="pymysql", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
MySqlContainer("mariadb:10.6.5").with_bind_ports(3306, 32785) as container,
):
url = container.get_connection_url()
pattern = r"mysql\+pymysql:\/\/demo:test@[\w,.]+:(3306|32785)\/custom_db"
assert re.match(pattern, url)


@pytest.mark.parametrize(
"dialect",
[
"mysql+pymysql",
"mysql+mariadb",
"mysql+mysqldb",
],
)
def test_mysql_dialect_expecting_error_on_mysql_prefix(dialect: str):
match = f"Please remove *.* prefix from dialect parameter"

with pytest.raises(ValueError, match=match):
_ = MySqlContainer("mariadb:10.6.5", dialect=dialect)


# This is a feature in the generic DbContainer class
# but it can't be tested on its own
# so is tested in various database modules:
Expand All @@ -75,18 +95,18 @@ def test_quoted_password():
user = "root"
password = "p@$%25+0&%rd :/!=?"
quoted_password = "p%40%24%2525+0%26%25rd %3A%2F%21%3D%3F"
driver = "pymysql"
with MySqlContainer("mariadb:10.6.5", username=user, password=password) as container:
dialect = "pymysql"
with MySqlContainer("mariadb:10.6.5", dialect=dialect, username=user, password=password) as container:
host = container.get_container_host_ip()
port = container.get_exposed_port(3306)
expected_url = f"mysql+{driver}://{user}:{quoted_password}@{host}:{port}/test"
expected_url = f"mysql+{dialect}://{user}:{quoted_password}@{host}:{port}/test"
url = container.get_connection_url()
assert url == expected_url

with sqlalchemy.create_engine(expected_url).begin() as connection:
connection.execute(sqlalchemy.text("select version()"))

raw_pass_url = f"mysql+{driver}://{user}:{password}@{host}:{port}/test"
raw_pass_url = f"mysql+{dialect}://{user}:{password}@{host}:{port}/test"
with pytest.raises(Exception):
with sqlalchemy.create_engine(raw_pass_url).begin() as connection:
connection.execute(sqlalchemy.text("select version()"))

0 comments on commit 8d77bd3

Please sign in to comment.