Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(postgres): Add seed feature to postgres #576

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion core/testcontainers/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import tarfile
from io import BytesIO
from pathlib import Path

Check warning on line 15 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L13-L15

Added lines #L13 - L15 were not covered by tests
from typing import Optional
from urllib.parse import quote

Expand All @@ -26,6 +29,10 @@
except ImportError:
pass

SENTINEL_FOLDER = "/sentinel"
SENTINEL_FILENAME = "completed"
SENTINEL_FULLPATH = f"{SENTINEL_FOLDER}/{SENTINEL_FILENAME}"

Check warning on line 34 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L32-L34

Added lines #L32 - L34 were not covered by tests


class DbContainer(DockerContainer):
"""
Expand Down Expand Up @@ -80,4 +87,49 @@
raise NotImplementedError

def _transfer_seed(self) -> None:
pass
if self.seed is None:
return
src_path = Path(self.seed)
container = self.get_wrapped_container()
transfer_folder(container, src_path, self.seed_mountpoint)
transfer_file_contents(container, "Sentinel completed", SENTINEL_FOLDER)

Check warning on line 95 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L91-L95

Added lines #L91 - L95 were not covered by tests

def override_command_for_seed(self, startup_command):

Check warning on line 97 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L97

Added line #L97 was not covered by tests
"""Replace the image's command for seed purposes"""
image_cmd = get_image_cmd(self._docker.client, self.image)
cmd_full = " ".join([startup_command, image_cmd])
command = f"""sh -c "

Check warning on line 101 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L99-L101

Added lines #L99 - L101 were not covered by tests
mkdir {SENTINEL_FOLDER};
while [ ! -f {SENTINEL_FULLPATH} ];
do
sleep 0.1;
done;
bash -c '{cmd_full}'"
"""
self.with_command(command)

Check warning on line 109 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L109

Added line #L109 was not covered by tests


def get_image_cmd(client, image):
image_info = client.api.inspect_image(image)
cmd_list: list[str] = image_info["Config"]["Cmd"]
return " ".join(cmd_list)

Check warning on line 115 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L112-L115

Added lines #L112 - L115 were not covered by tests


def transfer_folder(container, local_path, remote_path):

Check warning on line 118 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L118

Added line #L118 was not covered by tests
"""Transfer local_path to remote_path on the given container, using put_archive"""
with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar:
for filename in local_path.iterdir():
tar.add(filename.absolute(), arcname=filename.relative_to(local_path))
archive.seek(0)
container.put_archive(remote_path, archive)

Check warning on line 124 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L122-L124

Added lines #L122 - L124 were not covered by tests


def transfer_file_contents(container, content_str, remote_path):

Check warning on line 127 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L127

Added line #L127 was not covered by tests
"""Create a file from raw content_str to remote_path on container, via put_archive"""
with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar:
tarinfo = tarfile.TarInfo(name=SENTINEL_FILENAME)
content = BytesIO(bytes(content_str, encoding="utf-8"))
tarinfo.size = len(content.getvalue())
tar.addfile(tarinfo, fileobj=content)
archive.seek(0)
container.put_archive(remote_path, archive)

Check warning on line 135 in core/testcontainers/core/generic.py

View check run for this annotation

Codecov / codecov/patch

core/testcontainers/core/generic.py#L130-L135

Added lines #L130 - L135 were not covered by tests
27 changes: 11 additions & 16 deletions modules/mysql/testcontainers/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
# License for the specific language governing permissions and limitations
# under the License.
import re
import tarfile
from io import BytesIO
from os import environ
from pathlib import Path
from typing import Optional

from testcontainers.core.generic import DbContainer
from testcontainers.core.utils import raise_for_deprecated_parameter
from testcontainers.core.utils import raise_for_deprecated_parameter, setup_logger
from testcontainers.core.waiting_utils import wait_for_logs

LOGGER = setup_logger(__name__)


class MySqlContainer(DbContainer):
"""
Expand Down Expand Up @@ -50,8 +49,10 @@ class MySqlContainer(DbContainer):
automatically.

.. doctest::

>>> import sqlalchemy
>>> from testcontainers.mysql import MySqlContainer

>>> with MySqlContainer(seed="../../tests/seeds/") as mysql:
... engine = sqlalchemy.create_engine(mysql.get_connection_url())
... with engine.begin() as connection:
Expand All @@ -61,15 +62,18 @@ class MySqlContainer(DbContainer):

"""

seed_mountpoint: str = "/docker-entrypoint-initdb.d/"
startup_command: str = "source /usr/local/bin/docker-entrypoint.sh; _main "

def __init__(
self,
image: str = "mysql:latest",
username: Optional[str] = None,
root_password: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
port: int = 3306,
seed: Optional[str] = None,
port: int = 3306,
**kwargs,
) -> None:
raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username")
Expand All @@ -88,6 +92,8 @@ def __init__(
if self.username == "root":
self.root_password = self.password
self.seed = seed
if self.seed is not None:
super().override_command_for_seed(self.startup_command)

def _configure(self) -> None:
self.with_env("MYSQL_ROOT_PASSWORD", self.root_password)
Expand All @@ -107,14 +113,3 @@ def get_connection_url(self) -> str:
return super()._create_connection_url(
dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port
)

def _transfer_seed(self) -> None:
if self.seed is None:
return
src_path = Path(self.seed)
dest_path = "/docker-entrypoint-initdb.d/"
with BytesIO() as archive, tarfile.TarFile(fileobj=archive, mode="w") as tar:
for filename in src_path.iterdir():
tar.add(filename.absolute(), arcname=filename.relative_to(src_path))
archive.seek(0)
self.get_wrapped_container().put_archive(dest_path, archive)
25 changes: 25 additions & 0 deletions modules/postgres/testcontainers/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,29 @@ class PostgresContainer(DbContainer):
... version, = result.fetchone()
>>> version
'PostgreSQL 16...'

The optional :code:`seed` parameter enables arbitrary SQL files to be loaded.
This is perfect for schema and sample data. This works by mounting the seed to
`/docker-entrypoint-initdb./d`, which containerized Postgres are set up to load
automatically.

.. doctest::

>>> from testcontainers.postgres import PostgresContainer
>>> import sqlalchemy
>>>
>>> with PostgresContainer(seed="../../tests/seeds/") as postgres:
... engine = sqlalchemy.create_engine(postgres.get_connection_url())
... with engine.begin() as connection:
... query = "select * from stuff" # Can now rely on schema/data
... result = connection.execute(sqlalchemy.text(query))
... first_stuff, = result.fetchone()

"""

seed_mountpoint: str = "/docker-entrypoint-initdb.d/"
startup_command: str = "source /usr/local/bin/docker-entrypoint.sh; _main "

def __init__(
self,
image: str = "postgres:latest",
Expand All @@ -55,6 +76,7 @@ def __init__(
password: Optional[str] = None,
dbname: Optional[str] = None,
driver: Optional[str] = "psycopg2",
seed: Optional[str] = None,
**kwargs,
) -> None:
raise_for_deprecated_parameter(kwargs, "user", "username")
Expand All @@ -64,6 +86,9 @@ def __init__(
self.dbname: str = dbname or os.environ.get("POSTGRES_DB", "test")
self.port = port
self.driver = f"+{driver}" if driver else ""
self.seed = seed
if self.seed is not None:
super().override_command_for_seed(self.startup_command)

self.with_exposed_ports(self.port)

Expand Down
5 changes: 5 additions & 0 deletions modules/postgres/tests/seeds/01-schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Sample SQL schema, no data
CREATE TABLE stuff (
id integer primary key generated always as identity,
name text NOT NULL
);
4 changes: 4 additions & 0 deletions modules/postgres/tests/seeds/02-seeds.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Sample data, to be loaded after the schema
INSERT INTO stuff (name)
VALUES ('foo'), ('bar'), ('qux'), ('frob')
RETURNING id;
11 changes: 11 additions & 0 deletions modules/postgres/tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def test_docker_run_postgres_with_sqlalchemy():
assert row[0].lower().startswith("postgresql 9.5")


def test_docker_run_postgres_seeds_with_sqlalchemy():
# Avoid pytest CWD path issues
SEEDS_PATH = (Path(__file__).parent / "seeds").absolute()
postgres_container = PostgresContainer("postgres", seed=SEEDS_PATH)
with postgres_container as postgres:
engine = sqlalchemy.create_engine(postgres.get_connection_url())
with engine.begin() as connection:
result = connection.execute(sqlalchemy.text("select * from stuff"))
assert len(list(result)) == 4, "Should have gotten all the stuff"


def test_docker_run_postgres_with_driver_pg8000():
postgres_container = PostgresContainer("postgres:9.5", driver="pg8000")
with postgres_container as postgres:
Expand Down
Loading