diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index 515c2831..6dd635e6 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -70,8 +70,12 @@ def _create_connection_url( def start(self) -> "DbContainer": self._configure() super().start() + self._transfer_seed() self._connect() return self def _configure(self) -> None: raise NotImplementedError + + def _transfer_seed(self) -> None: + pass diff --git a/modules/mysql/testcontainers/mysql/__init__.py b/modules/mysql/testcontainers/mysql/__init__.py index 1b0751bc..46efbcfb 100644 --- a/modules/mysql/testcontainers/mysql/__init__.py +++ b/modules/mysql/testcontainers/mysql/__init__.py @@ -11,7 +11,10 @@ # License for the specific language governing permissions and limitations # under the License. import re +import tarfile +from io import BytesIO from os import environ +from pathlib import Path from typing import Optional from testcontainers.core.generic import DbContainer @@ -40,6 +43,22 @@ class MySqlContainer(DbContainer): ... with engine.begin() as connection: ... result = connection.execute(sqlalchemy.text("select version()")) ... version, = result.fetchone() + + The optional :code:`seed` parameter enables arbitrary SQL files to be loaded. + This is perfect for schema and sample data. This works by mounting the seed to + `/docker-entrypoint-initdb./d`, which containerized MySQL are set up to load + automatically. + + .. doctest:: + >>> import sqlalchemy + >>> from testcontainers.mysql import MySqlContainer + >>> with MySqlContainer(seed="../../tests/seeds/") as mysql: + ... engine = sqlalchemy.create_engine(mysql.get_connection_url()) + ... with engine.begin() as connection: + ... query = "select * from stuff" # Can now rely on schema/data + ... result = connection.execute(sqlalchemy.text(query)) + ... first_stuff, = result.fetchone() + """ def __init__( @@ -50,6 +69,7 @@ def __init__( password: Optional[str] = None, dbname: Optional[str] = None, port: int = 3306, + seed: Optional[str] = None, **kwargs, ) -> None: raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username") @@ -67,6 +87,7 @@ def __init__( if self.username == "root": self.root_password = self.password + self.seed = seed def _configure(self) -> None: self.with_env("MYSQL_ROOT_PASSWORD", self.root_password) @@ -86,3 +107,14 @@ def get_connection_url(self) -> str: return super()._create_connection_url( dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port ) + + def _transfer_seed(self) -> None: + if self.seed is None: + return + src_path = Path(self.seed) + dest_path = "/docker-entrypoint-initdb.d/" + with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar: + for filename in src_path.iterdir(): + tar.add(filename.absolute(), arcname=filename.relative_to(src_path)) + archive.seek(0) + self.get_wrapped_container().put_archive(dest_path, archive) diff --git a/modules/mysql/tests/seeds/01-schema.sql b/modules/mysql/tests/seeds/01-schema.sql new file mode 100644 index 00000000..ea398244 --- /dev/null +++ b/modules/mysql/tests/seeds/01-schema.sql @@ -0,0 +1,6 @@ +-- Sample SQL schema, no data +CREATE TABLE `stuff` ( + `id` mediumint NOT NULL AUTO_INCREMENT, + `name` VARCHAR(63) NOT NULL, + PRIMARY KEY (`id`) +); diff --git a/modules/mysql/tests/seeds/02-seeds.sql b/modules/mysql/tests/seeds/02-seeds.sql new file mode 100644 index 00000000..7ce78903 --- /dev/null +++ b/modules/mysql/tests/seeds/02-seeds.sql @@ -0,0 +1,3 @@ +-- Sample data, to be loaded after the schema +INSERT INTO stuff (name) +VALUES ("foo"), ("bar"), ("qux"), ("frob"); diff --git a/modules/mysql/tests/test_mysql.py b/modules/mysql/tests/test_mysql.py index ee1e2b45..847f99df 100644 --- a/modules/mysql/tests/test_mysql.py +++ b/modules/mysql/tests/test_mysql.py @@ -1,3 +1,4 @@ +from pathlib import Path import re from unittest import mock @@ -29,6 +30,18 @@ def test_docker_run_legacy_mysql(): assert row[0].startswith("5.7.44") +@pytest.mark.skipif(is_arm(), reason="mysql container not available for ARM") +def test_docker_run_mysql_8_seed(): + # Avoid pytest CWD path issues + SEEDS_PATH = (Path(__file__).parent / "seeds").absolute() + config = MySqlContainer("mysql:8", seed=SEEDS_PATH) + with config as mysql: + engine = sqlalchemy.create_engine(mysql.get_connection_url()) + with engine.begin() as connection: + result = connection.execute(sqlalchemy.text("select * from stuff")) + assert len(list(result)) == 4, "Should have gotten all the stuff" + + @pytest.mark.parametrize("version", ["11.3.2", "10.11.7"]) def test_docker_run_mariadb(version: str): with MySqlContainer(f"mariadb:{version}") as mariadb: