From d5b855323be06f8d1395dd480a347f0efef75703 Mon Sep 17 00:00:00 2001 From: Vemund Santi Date: Wed, 3 Apr 2024 14:40:39 +0200 Subject: [PATCH] fix(core): Improve typing for common container usage scenarios (#523) Improves type hints for type checking in a common use cases: ```python with MySqlContainer("mysql:8").with_env("some", "value") as mysql: url = mysql.get_connection_url() # get_connection_url would previously be an unknown member here ``` And, also improved type hinting for the custom `DockerClient`'s `run` command, where the linter no longer reports an error due to missing parameter types: ```python DockerClient.run("nginx") # Previously this would report "Argument missing for parameter "image" ``` --- core/testcontainers/core/container.py | 24 ++++++++++++----------- core/testcontainers/core/docker_client.py | 17 ++++++++++++++-- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/core/testcontainers/core/container.py b/core/testcontainers/core/container.py index 3e1e1ba1..559a4ffe 100644 --- a/core/testcontainers/core/container.py +++ b/core/testcontainers/core/container.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional import docker.errors +from typing_extensions import Self from testcontainers.core.config import ( RYUK_DISABLED, @@ -53,29 +54,29 @@ def __init__( self._name = None self._kwargs = kwargs - def with_env(self, key: str, value: str) -> "DockerContainer": + def with_env(self, key: str, value: str) -> Self: self.env[key] = value return self - def with_bind_ports(self, container: int, host: Optional[int] = None) -> "DockerContainer": + def with_bind_ports(self, container: int, host: Optional[int] = None) -> Self: self.ports[container] = host return self - def with_exposed_ports(self, *ports: int) -> "DockerContainer": + def with_exposed_ports(self, *ports: int) -> Self: for port in ports: self.ports[port] = None return self - def with_kwargs(self, **kwargs) -> "DockerContainer": + def with_kwargs(self, **kwargs) -> Self: self._kwargs = kwargs return self - def maybe_emulate_amd64(self) -> "DockerContainer": + def maybe_emulate_amd64(self) -> Self: if is_arm(): return self.with_kwargs(platform="linux/amd64") return self - def start(self): + def start(self) -> Self: if not RYUK_DISABLED and self.image != RYUK_IMAGE: logger.debug("Creating Ryuk container") Reaper.get_instance() @@ -95,10 +96,11 @@ def start(self): return self def stop(self, force=True, delete_volume=True) -> None: - self._container.remove(force=force, v=delete_volume) + if self._container: + self._container.remove(force=force, v=delete_volume) self.get_docker_client().client.close() - def __enter__(self): + def __enter__(self) -> Self: return self.start() def __exit__(self, exc_type, exc_val, exc_tb) -> None: @@ -138,15 +140,15 @@ def get_exposed_port(self, port: int) -> str: return port return mapped_port - def with_command(self, command: str) -> "DockerContainer": + def with_command(self, command: str) -> Self: self._command = command return self - def with_name(self, name: str) -> "DockerContainer": + def with_name(self, name: str) -> Self: self._name = name return self - def with_volume_mapping(self, host: str, container: str, mode: str = "ro") -> "DockerContainer": + def with_volume_mapping(self, host: str, container: str, mode: str = "ro") -> Self: mapping = {"bind": container, "mode": mode} self.volumes[host] = mapping return self diff --git a/core/testcontainers/core/docker_client.py b/core/testcontainers/core/docker_client.py index c72c48e0..22a9a4ef 100644 --- a/core/testcontainers/core/docker_client.py +++ b/core/testcontainers/core/docker_client.py @@ -18,10 +18,11 @@ import urllib.parse from os.path import exists from pathlib import Path -from typing import Optional, Union +from typing import Callable, Optional, TypeVar, Union import docker from docker.models.containers import Container, ContainerCollection +from typing_extensions import ParamSpec from testcontainers.core.labels import SESSION_ID, create_labels from testcontainers.core.utils import default_gateway_ip, inside_container, setup_logger @@ -30,6 +31,18 @@ TC_FILE = ".testcontainers.properties" TC_GLOBAL = Path.home() / TC_FILE +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _wrapped_container_collection(function: Callable[_P, _T]) -> Callable[_P, _T]: + + @ft.wraps(ContainerCollection.run) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + return function(*args, **kwargs) + + return wrapper + class DockerClient: """ @@ -48,7 +61,7 @@ def __init__(self, **kwargs) -> None: self.client.api.headers["x-tc-sid"] = SESSION_ID self.client.api.headers["User-Agent"] = "tc-python/" + importlib.metadata.version("testcontainers") - @ft.wraps(ContainerCollection.run) + @_wrapped_container_collection def run( self, image: str,