-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added new process handler that controls the parallelization mode glob…
…ally: multiprocessing, ray or single_thread
- Loading branch information
1 parent
f815de7
commit 5477b93
Showing
6 changed files
with
126 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
"""Main init for package.""" | ||
from symmer.process_handler import process | ||
from symmer.operators import PauliwordOp, QuantumState | ||
from symmer.projection import QubitTapering, ContextualSubspace, QubitSubspaceManager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,119 @@ | ||
import multiprocessing as mp | ||
import ray | ||
import os | ||
import numpy as np | ||
from ray import remote, put, get | ||
from multiprocessing import Process, Queue | ||
|
||
class ProcessHandler: | ||
|
||
method = 'mp' | ||
method = 'mp' | ||
verbose = False | ||
|
||
def __init__(self): | ||
self.n_logical_cores = mp.cpu_count() | ||
self.n_logical_cores = os.cpu_count() | ||
|
||
def _process_ray(self, func): | ||
def worker(iter, shared=None): | ||
return func(iter, shared) | ||
return worker | ||
|
||
def _process_mp(self, func): | ||
def worker(iter, shared=None): | ||
def prepare_chunks(self, iter): | ||
""" split a list into smaller sized chunks | ||
""" | ||
iter = list(iter) | ||
self.n_chunks = min(len(iter), self.n_logical_cores) | ||
chunk_size = int(np.ceil(len(iter)/self.n_chunks)) | ||
indices = np.append(np.arange(self.n_chunks)*chunk_size, None) | ||
for i,j in zip(indices[:-1], indices[1:]): | ||
yield iter[i:j] | ||
|
||
def _process_ray(self, func, iter, shared): | ||
""" Helper function for ray processing | ||
""" | ||
if self.verbose: | ||
print(f'*** executing in ray mode ***') | ||
# duplicate func with ray.remote wrapper : | ||
@remote(num_cpus=self.n_logical_cores, | ||
# runtime_env={ | ||
# "env_vars": { | ||
# "NUMBA_NUM_THREADS": os.getenv("NUMBA_NUM_THREADS"), | ||
# "OMP_NUM_THREADS": os.getenv("NUMBA_NUM_THREADS"), | ||
# "NUMEXPR_MAX_THREADS": str(self.n_logical_cores) | ||
# } | ||
# } | ||
) | ||
def _func(iter, shared): | ||
return func(iter, shared) | ||
return worker | ||
# place into shared memory: | ||
shared_obj = put(shared) | ||
# split iterable into smaller chunks and parallelize remote instances: | ||
results = get( | ||
[ | ||
_func.remote(chunk, shared_obj) | ||
for chunk in self.prepare_chunks(iter) | ||
] | ||
) | ||
# flatten the list and return: | ||
return [a for b in results for a in b] | ||
|
||
def _process_mp(self, func, iter, shared): | ||
""" Helper function for multiprocessing | ||
""" | ||
if self.verbose: | ||
print(f'*** executing in multiprocessing mode ***') | ||
# wrapper function for putting results into queue | ||
def _func(iter, shared, _queue=None): | ||
data_out = func(iter, shared) | ||
_queue.put(data_out) | ||
|
||
chunks = list(self.prepare_chunks(iter)) | ||
procs = [] # for storing processes | ||
queue = Queue(self.n_chunks) # storage of data from processes | ||
for chunk in chunks: | ||
proc = Process(target=_func, args=(chunk, shared, queue)) | ||
procs.append(proc) | ||
proc.start() | ||
# retrieve data from the queue | ||
data = [] | ||
for _ in range(self.n_chunks): | ||
data += queue.get() | ||
# complete the processes | ||
for proc in procs: | ||
proc.join() | ||
return data | ||
|
||
def _process_single(self, func, iter, shared): | ||
""" Helper function for single threading | ||
""" | ||
if self.verbose: | ||
print(f'*** executing in single-threaded mode ***') | ||
return func(iter, shared) | ||
|
||
def parallelize(self, func): | ||
if self.method == 'mp': | ||
return self._process_mp(func) | ||
elif self.method == 'ray': | ||
return self._process_ray(func) | ||
|
||
if __name__ == '__main__': | ||
PH = ProcessHandler() | ||
|
||
@PH.parallelize | ||
def add_n(l, n): | ||
return [i+n for i in l] | ||
def wrapper(iter, shared): | ||
|
||
_func = lambda iter,shared: [func(i, shared) for i in iter] | ||
|
||
if self.method == 'mp': | ||
return self._process_mp(_func, iter, shared) | ||
elif self.method == 'ray': | ||
return self._process_ray(_func, iter, shared) | ||
elif self.method == 'single_thread': | ||
return self._process_single(_func, iter, shared) | ||
else: | ||
raise ValueError(f'Invalid processing method {self.method}, must be ray, mp or single_thread.') | ||
|
||
return wrapper | ||
|
||
process = ProcessHandler() | ||
|
||
print(add_n([1,2,3,4,5,6,7], 3)) | ||
if __name__ == '__main__': | ||
|
||
@process.parallelize | ||
def multiply_list(iter, shared): | ||
return [i*shared for i in iter] | ||
|
||
l = list(range(100)) | ||
|
||
process.method = 'single_thread' | ||
print(multiply_list(l,2)) | ||
process.method = 'mp' | ||
print(multiply_list(l,2)) | ||
process.method = 'ray' | ||
print(multiply_list(l,2)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters