From 118dee88be60edccdda4ee89fcd88f71f85be792 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Mon, 20 Jan 2025 11:41:50 +0100 Subject: [PATCH] refactor: refactor semi ordered map --- qdrant_client/parallel_processor.py | 63 ++--------------------------- 1 file changed, 4 insertions(+), 59 deletions(-) diff --git a/qdrant_client/parallel_processor.py b/qdrant_client/parallel_processor.py index 8c43607e..8df62001 100644 --- a/qdrant_client/parallel_processor.py +++ b/qdrant_client/parallel_processor.py @@ -130,6 +130,7 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite pushed = 0 read = 0 for item in stream: + self.check_worker_health() if pushed - read < self.queue_size: try: out_item = self.output_queue.get_nowait() @@ -174,6 +175,9 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite self.input_queue.join_thread() self.output_queue.join_thread() + def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]: + return self.unordered_map(enumerate(stream), *args, **kwargs) + def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]: buffer = defaultdict(int) next_expected = 0 @@ -184,65 +188,6 @@ def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Itera yield buffer.pop(next_expected) next_expected += 1 - def semi_ordered_map( - self, stream: Iterable[Any], *args: Any, **kwargs: Any - ) -> Iterable[tuple[int, Any]]: - try: - self.start(**kwargs) - - assert self.input_queue is not None, "Input queue was not initialized" - assert self.output_queue is not None, "Output queue was not initialized" - - pushed = 0 - read = 0 - for idx, item in enumerate(stream): - self.check_worker_health() - if pushed - read < self.queue_size: - try: - out_item = self.output_queue.get_nowait() - except Empty: - out_item = None - else: - try: - out_item = self.output_queue.get(timeout=processing_timeout) - except Empty as e: - self.join_or_terminate() - raise e - - if out_item is not None: - if out_item == QueueSignals.error: - self.join_or_terminate() - raise RuntimeError("Thread unexpectedly terminated") - yield out_item - read += 1 - - self.input_queue.put((idx, item)) - pushed += 1 - - for _ in range(self.num_workers): - self.input_queue.put(QueueSignals.stop) - - while read < pushed: - self.check_worker_health() - out_item = self.output_queue.get(timeout=processing_timeout) - if out_item == QueueSignals.error: - self.join_or_terminate() - raise RuntimeError("Thread unexpectedly terminated") - yield out_item - read += 1 - finally: - assert self.input_queue is not None, "Input queue is None" - assert self.output_queue is not None, "Output queue is None" - self.join() - self.input_queue.close() - self.output_queue.close() - if self.emergency_shutdown: - self.input_queue.cancel_join_thread() - self.output_queue.cancel_join_thread() - else: - self.input_queue.join_thread() - self.output_queue.join_thread() - def check_worker_health(self) -> None: """ Checks if any worker process has terminated unexpectedly