Skip to content

Commit

Permalink
refactor: refactor semi ordered map
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Jan 20, 2025
1 parent 3085187 commit 118dee8
Showing 1 changed file with 4 additions and 59 deletions.
63 changes: 4 additions & 59 deletions qdrant_client/parallel_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite
pushed = 0
read = 0
for item in stream:
self.check_worker_health()
if pushed - read < self.queue_size:
try:
out_item = self.output_queue.get_nowait()
Expand Down Expand Up @@ -174,6 +175,9 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite
self.input_queue.join_thread()
self.output_queue.join_thread()

def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
return self.unordered_map(enumerate(stream), *args, **kwargs)

def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
buffer = defaultdict(int)
next_expected = 0
Expand All @@ -184,65 +188,6 @@ def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Itera
yield buffer.pop(next_expected)
next_expected += 1

def semi_ordered_map(
self, stream: Iterable[Any], *args: Any, **kwargs: Any
) -> Iterable[tuple[int, Any]]:
try:
self.start(**kwargs)

assert self.input_queue is not None, "Input queue was not initialized"
assert self.output_queue is not None, "Output queue was not initialized"

pushed = 0
read = 0
for idx, item in enumerate(stream):
self.check_worker_health()
if pushed - read < self.queue_size:
try:
out_item = self.output_queue.get_nowait()
except Empty:
out_item = None
else:
try:
out_item = self.output_queue.get(timeout=processing_timeout)
except Empty as e:
self.join_or_terminate()
raise e

if out_item is not None:
if out_item == QueueSignals.error:
self.join_or_terminate()
raise RuntimeError("Thread unexpectedly terminated")
yield out_item
read += 1

self.input_queue.put((idx, item))
pushed += 1

for _ in range(self.num_workers):
self.input_queue.put(QueueSignals.stop)

while read < pushed:
self.check_worker_health()
out_item = self.output_queue.get(timeout=processing_timeout)
if out_item == QueueSignals.error:
self.join_or_terminate()
raise RuntimeError("Thread unexpectedly terminated")
yield out_item
read += 1
finally:
assert self.input_queue is not None, "Input queue is None"
assert self.output_queue is not None, "Output queue is None"
self.join()
self.input_queue.close()
self.output_queue.close()
if self.emergency_shutdown:
self.input_queue.cancel_join_thread()
self.output_queue.cancel_join_thread()
else:
self.input_queue.join_thread()
self.output_queue.join_thread()

def check_worker_health(self) -> None:
"""
Checks if any worker process has terminated unexpectedly
Expand Down

0 comments on commit 118dee8

Please sign in to comment.