diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index b2cd3010d..98e2340ec 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -10,6 +10,9 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import tarfile +from io import BytesIO +from pathlib import Path from typing import Optional from urllib.parse import quote @@ -26,6 +29,10 @@ except ImportError: pass +SENTINEL_FOLDER = "/sentinel" +SENTINEL_FILENAME = "completed" +SENTINEL_FULLPATH = f"{SENTINEL_FOLDER}/{SENTINEL_FILENAME}" + class DbContainer(DockerContainer): """ @@ -80,4 +87,49 @@ def _configure(self) -> None: raise NotImplementedError def _transfer_seed(self) -> None: - pass + if self.seed is None: + return + src_path = Path(self.seed) + container = self.get_wrapped_container() + transfer_folder(container, src_path, self.seed_mountpoint) + transfer_file_contents(container, "Sentinel completed", SENTINEL_FOLDER) + + def override_command_for_seed(self, startup_command): + """Replace the image's command for seed purposes""" + image_cmd = get_image_cmd(self._docker.client, self.image) + cmd_full = " ".join([startup_command, image_cmd]) + command = f"""sh -c " + mkdir {SENTINEL_FOLDER}; + while [ ! -f {SENTINEL_FULLPATH} ]; + do + sleep 0.1; + done; + bash -c '{cmd_full}'" + """ + self.with_command(command) + + +def get_image_cmd(client, image): + image_info = client.api.inspect_image(image) + cmd_list: list[str] = image_info["Config"]["Cmd"] + return " ".join(cmd_list) + + +def transfer_folder(container, local_path, remote_path): + """Transfer local_path to remote_path on the given container, using put_archive""" + with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar: + for filename in local_path.iterdir(): + tar.add(filename.absolute(), arcname=filename.relative_to(local_path)) + archive.seek(0) + container.put_archive(remote_path, archive) + + +def transfer_file_contents(container, content_str, remote_path): + """Create a file from raw content_str to remote_path on container, via put_archive""" + with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar: + tarinfo = tarfile.TarInfo(name=SENTINEL_FILENAME) + content = BytesIO(bytes(content_str, encoding="utf-8")) + tarinfo.size = len(content.getvalue()) + tar.addfile(tarinfo, fileobj=content) + archive.seek(0) + container.put_archive(remote_path, archive) diff --git a/modules/mysql/testcontainers/mysql/__init__.py b/modules/mysql/testcontainers/mysql/__init__.py index 46efbcfbc..bfe7153e8 100644 --- a/modules/mysql/testcontainers/mysql/__init__.py +++ b/modules/mysql/testcontainers/mysql/__init__.py @@ -11,16 +11,15 @@ # License for the specific language governing permissions and limitations # under the License. import re -import tarfile -from io import BytesIO from os import environ -from pathlib import Path from typing import Optional from testcontainers.core.generic import DbContainer -from testcontainers.core.utils import raise_for_deprecated_parameter +from testcontainers.core.utils import raise_for_deprecated_parameter, setup_logger from testcontainers.core.waiting_utils import wait_for_logs +LOGGER = setup_logger(__name__) + class MySqlContainer(DbContainer): """ @@ -50,8 +49,10 @@ class MySqlContainer(DbContainer): automatically. .. doctest:: + >>> import sqlalchemy >>> from testcontainers.mysql import MySqlContainer + >>> with MySqlContainer(seed="../../tests/seeds/") as mysql: ... engine = sqlalchemy.create_engine(mysql.get_connection_url()) ... with engine.begin() as connection: @@ -61,6 +62,9 @@ class MySqlContainer(DbContainer): """ + seed_mountpoint: str = "/docker-entrypoint-initdb.d/" + startup_command: str = "source /usr/local/bin/docker-entrypoint.sh; _main " + def __init__( self, image: str = "mysql:latest", @@ -68,8 +72,8 @@ def __init__( root_password: Optional[str] = None, password: Optional[str] = None, dbname: Optional[str] = None, - port: int = 3306, seed: Optional[str] = None, + port: int = 3306, **kwargs, ) -> None: raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username") @@ -88,6 +92,8 @@ def __init__( if self.username == "root": self.root_password = self.password self.seed = seed + if self.seed is not None: + super().override_command_for_seed(self.startup_command) def _configure(self) -> None: self.with_env("MYSQL_ROOT_PASSWORD", self.root_password) @@ -107,14 +113,3 @@ def get_connection_url(self) -> str: return super()._create_connection_url( dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port ) - - def _transfer_seed(self) -> None: - if self.seed is None: - return - src_path = Path(self.seed) - dest_path = "/docker-entrypoint-initdb.d/" - with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar: - for filename in src_path.iterdir(): - tar.add(filename.absolute(), arcname=filename.relative_to(src_path)) - archive.seek(0) - self.get_wrapped_container().put_archive(dest_path, archive) diff --git a/modules/postgres/testcontainers/postgres/__init__.py b/modules/postgres/testcontainers/postgres/__init__.py index 80baef752..8bb58d0e1 100644 --- a/modules/postgres/testcontainers/postgres/__init__.py +++ b/modules/postgres/testcontainers/postgres/__init__.py @@ -45,8 +45,29 @@ class PostgresContainer(DbContainer): ... version, = result.fetchone() >>> version 'PostgreSQL 16...' + + The optional :code:`seed` parameter enables arbitrary SQL files to be loaded. + This is perfect for schema and sample data. This works by mounting the seed to + `/docker-entrypoint-initdb./d`, which containerized Postgres are set up to load + automatically. + + .. doctest:: + + >>> from testcontainers.postgres import PostgresContainer + >>> import sqlalchemy + >>> + >>> with PostgresContainer(seed="../../tests/seeds/") as postgres: + ... engine = sqlalchemy.create_engine(postgres.get_connection_url()) + ... with engine.begin() as connection: + ... query = "select * from stuff" # Can now rely on schema/data + ... result = connection.execute(sqlalchemy.text(query)) + ... first_stuff, = result.fetchone() + """ + seed_mountpoint: str = "/docker-entrypoint-initdb.d/" + startup_command: str = "source /usr/local/bin/docker-entrypoint.sh; _main " + def __init__( self, image: str = "postgres:latest", @@ -55,6 +76,7 @@ def __init__( password: Optional[str] = None, dbname: Optional[str] = None, driver: Optional[str] = "psycopg2", + seed: Optional[str] = None, **kwargs, ) -> None: raise_for_deprecated_parameter(kwargs, "user", "username") @@ -64,6 +86,9 @@ def __init__( self.dbname: str = dbname or os.environ.get("POSTGRES_DB", "test") self.port = port self.driver = f"+{driver}" if driver else "" + self.seed = seed + if self.seed is not None: + super().override_command_for_seed(self.startup_command) self.with_exposed_ports(self.port) diff --git a/modules/postgres/tests/seeds/01-schema.sql b/modules/postgres/tests/seeds/01-schema.sql new file mode 100644 index 000000000..91e3d756a --- /dev/null +++ b/modules/postgres/tests/seeds/01-schema.sql @@ -0,0 +1,5 @@ +-- Sample SQL schema, no data +CREATE TABLE stuff ( + id integer primary key generated always as identity, + name text NOT NULL +); diff --git a/modules/postgres/tests/seeds/02-seeds.sql b/modules/postgres/tests/seeds/02-seeds.sql new file mode 100644 index 000000000..30fa938dd --- /dev/null +++ b/modules/postgres/tests/seeds/02-seeds.sql @@ -0,0 +1,4 @@ +-- Sample data, to be loaded after the schema +INSERT INTO stuff (name) +VALUES ('foo'), ('bar'), ('qux'), ('frob') +RETURNING id; diff --git a/modules/postgres/tests/test_postgres.py b/modules/postgres/tests/test_postgres.py index 38c856bf9..0944b07c5 100644 --- a/modules/postgres/tests/test_postgres.py +++ b/modules/postgres/tests/test_postgres.py @@ -38,6 +38,17 @@ def test_docker_run_postgres_with_sqlalchemy(): assert row[0].lower().startswith("postgresql 9.5") +def test_docker_run_postgres_seeds_with_sqlalchemy(): + # Avoid pytest CWD path issues + SEEDS_PATH = (Path(__file__).parent / "seeds").absolute() + postgres_container = PostgresContainer("postgres", seed=SEEDS_PATH) + with postgres_container as postgres: + engine = sqlalchemy.create_engine(postgres.get_connection_url()) + with engine.begin() as connection: + result = connection.execute(sqlalchemy.text("select * from stuff")) + assert len(list(result)) == 4, "Should have gotten all the stuff" + + def test_docker_run_postgres_with_driver_pg8000(): postgres_container = PostgresContainer("postgres:9.5", driver="pg8000") with postgres_container as postgres: