diff --git a/src/openeo_aggregator/partitionedjobs/zookeeper.py b/src/openeo_aggregator/partitionedjobs/zookeeper.py
index c93269a5..a14230de 100644
--- a/src/openeo_aggregator/partitionedjobs/zookeeper.py
+++ b/src/openeo_aggregator/partitionedjobs/zookeeper.py
@@ -95,15 +95,16 @@ def obtain_new_pjob_id(self, user_id: str, initial_value: bytes = b"", attempts:
         """Obtain new, unique partitioned job id"""
         # A couple of pjob_id attempts: start with current time based name and a suffix to counter collisions (if any)
         base_pjob_id = "pj-" + Clock.utcnow().strftime("%Y%m%d-%H%M%S")
-        for pjob_id in [base_pjob_id] + [f"{base_pjob_id}-{i}" for i in range(1, attempts)]:
-            try:
-                self._client.create(path=self._path(user_id, pjob_id), value=initial_value, makepath=True)
-                # We obtained our unique id
-                return pjob_id
-            except NodeExistsError:
-                # TODO: check that NodeExistsError is thrown on existing job_ids
-                # TODO: add a sleep() to back off a bit?
-                continue
+        with self._connect():
+            for pjob_id in [base_pjob_id] + [f"{base_pjob_id}-{i}" for i in range(1, attempts)]:
+                try:
+                    self._client.create(path=self._path(user_id, pjob_id), value=initial_value, makepath=True)
+                    # We obtained our unique id
+                    return pjob_id
+                except NodeExistsError:
+                    # TODO: check that NodeExistsError is thrown on existing job_ids
+                    # TODO: add a sleep() to back off a bit?
+                    continue
         raise PartitionedJobFailure("Too much attempts to create new pjob_id")

     def insert(self, user_id: str, pjob: PartitionedJob) -> str:
@@ -147,12 +148,13 @@ def insert_sjob(
         title: Optional[str] = None,
         status: str = STATUS_INSERTED,
     ):
-        self._client.create(
-            path=self._path(user_id, pjob_id, "sjobs", sjob_id),
-            value=self.serialize(process_graph=subjob.process_graph, backend_id=subjob.backend_id, title=title),
-            makepath=True,
-        )
-        self.set_sjob_status(user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=status, create=True)
+        with self._connect():
+            self._client.create(
+                path=self._path(user_id, pjob_id, "sjobs", sjob_id),
+                value=self.serialize(process_graph=subjob.process_graph, backend_id=subjob.backend_id, title=title),
+                makepath=True,
+            )
+            self.set_sjob_status(user_id=user_id, pjob_id=pjob_id, sjob_id=sjob_id, status=status, create=True)

     def get_pjob_metadata(self, user_id: str, pjob_id: str) -> dict:
         """Get metadata of partitioned job, given by storage id."""
diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py
index 77e3df59..54db1bfb 100644
--- a/src/openeo_aggregator/testing.py
+++ b/src/openeo_aggregator/testing.py
@@ -32,7 +32,12 @@ def stop(self):
         assert self.state == "open"
         self.state = "closed"

+    def _assert_open(self):
+        if not self.state == "open":
+            raise kazoo.exceptions.ConnectionClosedError("Connection has been closed")
+
     def create(self, path: str, value, makepath: bool = False):
+        self._assert_open()
         if path in self.data:
             raise kazoo.exceptions.NodeExistsError()
         parent = str(pathlib.Path(path).parent)
@@ -44,20 +49,24 @@ def create(self, path: str, value, makepath: bool = False):
         self.data[path] = value

     def exists(self, path):
+        self._assert_open()
         return path in self.data

     def get(self, path):
+        self._assert_open()
         if path not in self.data:
             raise kazoo.exceptions.NoNodeError()
         return (self.data[path], None)

     def get_children(self, path):
+        self._assert_open()
         if path not in self.data:
             raise kazoo.exceptions.NoNodeError()
         parent = path.split("/")
         return [p.split("/")[-1] for p in self.data if p.split("/")[:-1] == parent]

     def set(self, path, value, version=-1):
+        self._assert_open()
         if path not in self.data:
             raise kazoo.exceptions.NoNodeError()
         self.data[path] = value