Merge pull request #50 from moeyensj/shared-memory-instance
Store observations in shared memory specific to instance
moeyensj authored Jul 26, 2023
2 parents f21e2fd + 220db8f commit 1d3342e
Showing 2 changed files with 140 additions and 65 deletions.
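
The substance of the change: the shared-memory block holding the observations record array is now named per process (`DIFI_ARRAY_{pid}`) instead of using the fixed name `DIFI_ARRAY`, and the run methods unlink the block even when a worker raises. A minimal standalone sketch of the underlying pattern (illustrative names and data, not the difi API):

```python
import os
from multiprocessing import shared_memory

import numpy as np

# A tiny record array standing in for difi's observations table.
dtypes = [("obs_id", np.int64), ("time", np.float64)]
observations = np.array([(1, 59000.0), (2, 59000.1)], dtype=dtypes)

# Suffix the block name with the creating process's PID so that concurrent
# runs on the same machine do not collide on a single fixed name.
name = f"DIFI_ARRAY_{os.getpid()}"
shm = shared_memory.SharedMemory(name, create=True, size=observations.nbytes)

# Copy the observations into the shared buffer.
view = np.ndarray(observations.shape, dtype=observations.dtype, buffer=shm.buf)
view[:] = observations[:]
del view  # release the exported buffer before closing this handle
shm.close()  # on POSIX the block itself persists until unlink()

# Any process on the machine can now attach by name and read.
attached = shared_memory.SharedMemory(name=name)
obs = np.ndarray((2,), dtype=dtypes, buffer=attached.buf)
first_id = int(obs["obs_id"][0])  # copy values out before detaching
del obs
attached.close()

# The owner removes the block once all workers are done, as
# _clear_shared_record_array does below.
cleanup = shared_memory.SharedMemory(name=name)
cleanup.unlink()
cleanup.close()
```
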
169 changes: 104 additions & 65 deletions difi/metrics.py
@@ -1,5 +1,6 @@
 import hashlib
 import multiprocessing as mp
+import os
 import warnings
 from abc import ABC, abstractmethod
 from itertools import combinations, repeat
@@ -556,7 +557,10 @@ def _store_as_shared_record_array(self, object_ids, obs_ids, times, ra, dec, nig
         self._num_observations = observations_array.shape[0]
         self._itemsize = observations_array.itemsize

-        shared_mem = shared_memory.SharedMemory("DIFI_ARRAY", create=True, size=observations_array.nbytes)
+        self._shared_memory_name = f"DIFI_ARRAY_{os.getpid()}"
+        shared_mem = shared_memory.SharedMemory(
+            self._shared_memory_name, create=True, size=observations_array.nbytes
+        )
         shared_memory_array = np.ndarray(
             observations_array.shape, dtype=observations_array.dtype, buffer=shared_mem.buf
         )
@@ -569,10 +573,19 @@ def _store_as_shared_record_array(object_ids, obs_ids, times, ra, dec, nig
         shared_mem.close()
         return

-    @staticmethod
-    def _clear_shared_record_array():
-        shared_mem = shared_memory.SharedMemory("DIFI_ARRAY")
+    def _clear_shared_record_array(self):
+        """
+        Clears the shared memory array for this instance of the metric.
+
+        Returns
+        -------
+        None
+        """
+        shared_mem = shared_memory.SharedMemory(self._shared_memory_name)
         shared_mem.unlink()
+        self._shared_memory_name = None
+        self._num_observations = 0
+        self._dtypes = None

     def _run_object_worker(
         self,
@@ -626,7 +639,7 @@ def _run_object_worker(
         ]

         # Load the observations from shared memory
-        existing_shared_mem = shared_memory.SharedMemory(name="DIFI_ARRAY")
+        existing_shared_mem = shared_memory.SharedMemory(name=self._shared_memory_name)
         observations = np.ndarray(
             num_obs,
             dtype=self._dtypes,
@@ -689,6 +702,7 @@ def run_by_object(
         discovery_probability: float = 1.0,
         ignore_after_discovery: bool = False,
         num_jobs: Optional[int] = 1,
+        clear_on_failure: bool = True,
     ) -> List[pd.DataFrame]:
         """
         Run the findability metric on the observations split by objects. For windows where there are many
@@ -718,6 +732,9 @@
         num_jobs : int, optional
             The number of jobs to run in parallel. If 1, then run in serial. If None, then use the number of
             CPUs on the machine.
+        clear_on_failure : bool, optional
+            If a failure occurs and this is False, then the shared memory array will not be cleared.
+            If True, then the shared memory array will be cleared.

         Returns
         -------
@@ -738,39 +755,45 @@
         # Split arrays by object
         split_by_object_slices = self._split_by_object(objects)

-        # Store the observations in a global variable so that the worker functions can access them
-        self._store_as_shared_record_array(objects, obs_ids, times, ra, dec, nights)
-
-        findable_lists: List[List[Dict[str, Any]]] = []
-        if num_jobs is None or num_jobs > 1:
-            pool = mp.Pool(num_jobs)
-            findable_lists = pool.starmap(
-                self._run_object_worker,
-                zip(
-                    split_by_object_slices,
-                    repeat(windows),
-                    repeat(discovery_opportunities),
-                    repeat(discovery_probability),
-                    repeat(ignore_after_discovery),
-                ),
-            )
-
-            pool.close()
-            pool.join()
-
-        else:
-            for object_indices in split_by_object_slices:
-                findable_lists.append(
-                    self._run_object_worker(
-                        object_indices,
-                        windows,
-                        discovery_opportunities=discovery_opportunities,
-                        discovery_probability=discovery_probability,
-                        ignore_after_discovery=ignore_after_discovery,
-                    )
-                )
-
-        self._clear_shared_record_array()
+        try:
+            # Store the observations in a global variable so that the worker functions can access them
+            self._store_as_shared_record_array(objects, obs_ids, times, ra, dec, nights)
+
+            findable_lists: List[List[Dict[str, Any]]] = []
+            if num_jobs is None or num_jobs > 1:
+                pool = mp.Pool(num_jobs)
+                findable_lists = pool.starmap(
+                    self._run_object_worker,
+                    zip(
+                        split_by_object_slices,
+                        repeat(windows),
+                        repeat(discovery_opportunities),
+                        repeat(discovery_probability),
+                        repeat(ignore_after_discovery),
+                    ),
+                )
+
+                pool.close()
+                pool.join()
+
+            else:
+                for object_indices in split_by_object_slices:
+                    findable_lists.append(
+                        self._run_object_worker(
+                            object_indices,
+                            windows,
+                            discovery_opportunities=discovery_opportunities,
+                            discovery_probability=discovery_probability,
+                            ignore_after_discovery=ignore_after_discovery,
+                        )
+                    )
+
+            self._clear_shared_record_array()
+
+        except Exception as e:
+            if clear_on_failure:
+                self._clear_shared_record_array()
+            raise e

         findable_flattened = [item for sublist in findable_lists for item in sublist]

@@ -832,7 +855,7 @@ def _run_window_worker(
         ]

         # Read observations from shared memory array
-        existing_shared_mem = shared_memory.SharedMemory(name="DIFI_ARRAY")
+        existing_shared_mem = shared_memory.SharedMemory(name=self._shared_memory_name)
         observations = np.ndarray(
             num_obs,
             dtype=self._dtypes,
@@ -897,6 +920,7 @@ def run_by_window(
         discovery_opportunities: bool = False,
         discovery_probability: float = 1.0,
         num_jobs: Optional[int] = 1,
+        clear_on_failure: bool = True,
     ) -> List[pd.DataFrame]:
         """
         Run the findability metric on the observations split by windows where each window will
@@ -922,6 +946,9 @@
         num_jobs : int, optional
             The number of jobs to run in parallel. If 1, then run in serial. If None, then use the number of
             CPUs on the machine.
+        clear_on_failure : bool, optional
+            If a failure occurs and this is False, then the shared memory array will not be cleared.
+            If True, then the shared memory array will be cleared.

         Returns
         -------
@@ -939,39 +966,45 @@
             observations["night"].values,
         )

-        # Store the observations in a global variable so that the worker functions can access them
-        self._store_as_shared_record_array(objects, obs_ids, times, ra, dec, nights)
-
-        # Find indices that split the observations into windows
-        split_by_window_slices = self._split_by_window(windows, nights)
-
-        findable_lists: List[List[Dict[str, Any]]] = []
-        if num_jobs is None or num_jobs > 1:
-            pool = mp.Pool(num_jobs)
-            findable_lists = pool.starmap(
-                self._run_window_worker,
-                zip(
-                    split_by_window_slices,
-                    range(len(windows)),
-                    repeat(discovery_opportunities),
-                    repeat(discovery_probability),
-                ),
-            )
-            pool.close()
-            pool.join()
-        else:
-            for i, window_slice in enumerate(split_by_window_slices):
-                findable_lists.append(
-                    self._run_window_worker(
-                        window_slice,
-                        i,
-                        discovery_opportunities=discovery_opportunities,
-                        discovery_probability=discovery_probability,
-                    )
-                )
-
-        self._clear_shared_record_array()
+        try:
+            # Store the observations in a global variable so that the worker functions can access them
+            self._store_as_shared_record_array(objects, obs_ids, times, ra, dec, nights)
+
+            # Find indices that split the observations into windows
+            split_by_window_slices = self._split_by_window(windows, nights)
+
+            findable_lists: List[List[Dict[str, Any]]] = []
+            if num_jobs is None or num_jobs > 1:
+                pool = mp.Pool(num_jobs)
+                findable_lists = pool.starmap(
+                    self._run_window_worker,
+                    zip(
+                        split_by_window_slices,
+                        range(len(windows)),
+                        repeat(discovery_opportunities),
+                        repeat(discovery_probability),
+                    ),
+                )
+                pool.close()
+                pool.join()
+
+            else:
+                for i, window_slice in enumerate(split_by_window_slices):
+                    findable_lists.append(
+                        self._run_window_worker(
+                            window_slice,
+                            i,
+                            discovery_opportunities=discovery_opportunities,
+                            discovery_probability=discovery_probability,
+                        )
+                    )
+
+            self._clear_shared_record_array()
+
+        except Exception as e:
+            if clear_on_failure:
+                self._clear_shared_record_array()
+            raise e

         findable_flattened = [item for sublist in findable_lists for item in sublist]

Expand All @@ -995,6 +1028,7 @@ def run(
by_object: bool = False,
ignore_after_discovery: bool = False,
num_jobs: Optional[int] = 1,
clear_on_failure: bool = True,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Run the findability metric on the observations.
@@ -1036,6 +1070,9 @@
         num_jobs : int, optional
             The number of jobs to run in parallel. If 1, then run in serial. If None, then use the number of
             CPUs on the machine.
+        clear_on_failure : bool, optional
+            If a failure occurs and this is False, then the shared memory array will not be cleared.
+            If True, then the shared memory array will be cleared.

         Returns
         -------
@@ -1067,6 +1104,7 @@
                 discovery_probability=discovery_probability,
                 ignore_after_discovery=ignore_after_discovery,
                 num_jobs=num_jobs,
+                clear_on_failure=clear_on_failure,
             )
         else:
             findable = self.run_by_window(
@@ -1075,6 +1113,7 @@
                 discovery_opportunities=discovery_opportunities,
                 discovery_probability=discovery_probability,
                 num_jobs=num_jobs,
+                clear_on_failure=clear_on_failure,
             )

         window_summary = self._create_window_summary(observations, windows, findable)
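
For callers, the visible change is the new `clear_on_failure` flag threaded through `run`, `run_by_object`, and `run_by_window`. A hypothetical usage sketch: the column names follow the tests below, but `run()`'s signature is only partially visible in this diff, so the positional call and the return order are assumptions.

```python
import pandas as pd

from difi.metrics import MinObsMetric

# Toy observations table using the column names the tests below rely on.
observations = pd.DataFrame(
    {
        "object_id": ["a", "a", "a"],
        "obs_id": [0, 1, 2],
        "time": [59000.0, 59000.1, 59000.2],
        "ra": [10.0, 10.1, 10.2],
        "dec": [-5.0, -5.0, -5.1],
        "night": [1, 1, 2],
    }
)

metric = MinObsMetric()
# clear_on_failure=False leaves the per-process shared memory block in place
# if a worker raises, so it can be inspected post mortem; the default (True)
# unlinks it before the exception propagates.
# (Return order assumed from the Tuple[pd.DataFrame, pd.DataFrame] annotation.)
findable, window_summary = metric.run(
    observations,
    num_jobs=2,
    clear_on_failure=False,
)
```

Note the cleanup is implemented as `except Exception` with a re-raise rather than `finally`: that is what lets the flag skip the unlink on failure while the success path always clears the block.
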
36 changes: 36 additions & 0 deletions difi/tests/test_metrics.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pytest

@@ -486,3 +488,37 @@ def test_calcFindableMinObs_assertion(test_observations):
     with pytest.raises(AssertionError):
         metric = MinObsMetric()
         metric.determine_object_findable(test_observations)
+
+
+def test_FindabilityMetrics_shared_memory(test_observations):
+    # Check that the function stores the observations in shared memory under
+    # the correct name
+    metric = MinObsMetric()
+
+    # Extract the data from the test observations
+    object_ids = test_observations["object_id"].values
+    obs_ids = test_observations["obs_id"].values
+    time = test_observations["time"].values
+    ra = test_observations["ra"].values
+    dec = test_observations["dec"].values
+    night = test_observations["night"].values
+
+    # Store the observations in shared memory
+    metric._store_as_shared_record_array(object_ids, obs_ids, time, ra, dec, night)
+
+    # Check that the shared memory array has the correct name
+    assert metric._shared_memory_name == f"DIFI_ARRAY_{os.getpid()}"
+    assert metric._num_observations == len(test_observations)
+    assert metric._dtypes == [
+        ("object_id", object_ids.dtype),
+        ("obs_id", obs_ids.dtype),
+        ("time", np.float64),
+        ("ra", np.float64),
+        ("dec", np.float64),
+        ("night", np.int64),
+    ]
+
+    metric._clear_shared_record_array()
+    assert metric._shared_memory_name is None
+    assert metric._num_observations == 0
+    assert metric._dtypes is None
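
The collision this test guards against is easy to reproduce in isolation: creating a second block under one fixed name raises `FileExistsError`, which is what two concurrent difi processes could previously hit with the shared `DIFI_ARRAY` name. A standalone sketch, not part of the test suite:

```python
import os
from multiprocessing import shared_memory

first = shared_memory.SharedMemory("DIFI_ARRAY", create=True, size=8)
try:
    # A second create under the same fixed name fails outright.
    shared_memory.SharedMemory("DIFI_ARRAY", create=True, size=8)
except FileExistsError:
    print("fixed name collides")
finally:
    first.unlink()
    first.close()

# A PID-suffixed name is unique per creating process.
print(f"DIFI_ARRAY_{os.getpid()}")
```
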
