diff --git a/modules/trino/testcontainers/trino/__init__.py b/modules/trino/testcontainers/trino/__init__.py index 97e3f9de..b9197ba2 100644 --- a/modules/trino/testcontainers/trino/__init__.py +++ b/modules/trino/testcontainers/trino/__init__.py @@ -11,6 +11,7 @@ # License for the specific language governing permissions and limitations # under the License. import re +import time from testcontainers.core.config import testcontainers_config as c from testcontainers.core.generic import DbContainer @@ -44,10 +45,16 @@ def _connect(self) -> None: port=self.get_exposed_port(self.port), user=self.user, ) - cur = conn.cursor() - cur.execute("SELECT 1") - cur.fetchall() - conn.close() + deadline = time.time() + c.max_tries + while time.time() < deadline: + try: + cur = conn.cursor() + cur.execute("SELECT * FROM tpch.tiny.nation LIMIT 1") + cur.fetchall() + return + except Exception: + time.sleep(c.sleep_time) + raise TimeoutError(f"Trino did not start within {c.max_tries:.3f} seconds") def get_connection_url(self): return f"trino://{self.user}@{self.get_container_host_ip()}:{self.port}"