Skip to content

Commit

Permalink
updates, still have to include spawn for parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
Saladino93 committed Sep 19, 2023
1 parent 5ed08e0 commit e384dea
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions plancklens/sims/phas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,30 @@ def __init__(self, fname, idtype="INTEGER"):

self.con = sqlite3.connect(fname, timeout=3600., detect_types=sqlite3.PARSE_DECLTYPES)

def add(self, idx):
def add(self, idx, state):
idx = int(idx)
try:
assert (self.get(idx) is None)
self.con.execute("INSERT INTO rngdb (id) VALUES (?)",
(idx))
keys_string = '_'.join(str(s) for s in state[1])
self.con.execute("INSERT INTO rngdb (id, type, pos, has_gauss, cached_gaussian, keys) VALUES (?,?,?,?,?,?)",
(idx, state[0], state[2], state[3], state[4], keys_string))
self.con.commit()
except:
print("rng_db::rngdb add failed!")

def get(self, idx):
idx = int(idx)
cur = self.con.cursor()
cur.execute("SELECT id FROM rngdb WHERE id=?", (idx,))#probably won't be necessary anymore
cur.execute("SELECT type, pos, has_gauss, cached_gaussian, keys FROM rngdb WHERE id=?", (idx,))
data = cur.fetchone()
cur.close()
if data is None:
return None
else:
assert (len(data) == 1)
id = data
assert (len(data) == 5)
typ, pos, has_gauss, cached_gaussian, keys = data
keys = np.array([int(a) for a in keys.split('_')], dtype=np.uint32)
return [id]
return [typ, keys, pos, has_gauss, cached_gaussian]

def delete(self, idx):
idx = int(idx)
Expand All @@ -62,11 +63,14 @@ def delete(self, idx):


class sim_lib(object):
"""Generic class for simulations. We store the index idx, and then use np.random.RandomState
"""Generic class for simulations where only rng state is stored.
np.random rng states are stored in a sqlite3 database. By default the rng state function is np.random.get_state.
The rng_db class is tuned for this state fct, you may need to adapt this.
"""

def __init__(self, lib_dir, nsims_max=None):
def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
if not os.path.exists(lib_dir) and mpi.rank == 0:
os.makedirs(lib_dir)
self.nmax = nsims_max
Expand All @@ -79,12 +83,27 @@ def __init__(self, lib_dir, nsims_max=None):
utils.hash_check(hsh, self.hashdict(), ignore=['lib_dir'], fn=fn_hash)

self._rng_db = rng_db(os.path.join(lib_dir, 'rngdb.db'), idtype='INTEGER')
self._get_rng_state = get_state_func

@staticmethod
def get_state(idx):
"""Returns a random number generator state from a seed. """
#sg = np.random.SeedSequence(idx)
#mt19937 = np.random.MT19937(sg)
#rs = np.random.RandomState(mt19937)
rs = np.random.Generator(np.random.MT19937())
dictionary = rs.__getstate__()
l = [dictionary[k] for k in dictionary.keys()]
return [l[0], l[1]['key'], l[1]['pos'], 0, 0.0]
#return rs.get_state()

def get_sim(self, idx, **kwargs):
"""Returns sim number idx and caches random number generator state. """
if self.has_nmax(): assert idx < self.nmax

if not self.is_stored(idx):
self._rng_db.add(idx)
#self._rng_db.add(idx, self._get_rng_state())
self._rng_db.add(idx, self.get_state(idx))
return self._build_sim_from_rng(self._rng_db.get(idx), **kwargs)

def has_nmax(self):
Expand Down Expand Up @@ -112,7 +131,7 @@ def hashdict(self):
"""Override this """
assert 0

def _build_sim_from_rng(self, idx):
def _build_sim_from_rng(self, rng_state):
"""Override this """
assert 0

Expand All @@ -122,9 +141,9 @@ def __init__(self, lib_dir, shape, **kwargs):
self.shape = shape
super(_pix_lib_phas, self).__init__(lib_dir, **kwargs)

def _build_sim_from_rng(self, idx, **kwargs):
rng = np.random.RandomState(idx)
return rng.standard_normal(self.shape)
def _build_sim_from_rng(self, rng_state, **kwargs):
np.random.set_state(rng_state)
return np.random.standard_normal(self.shape)

def hashdict(self):
return {'shape': self.shape}
Expand Down Expand Up @@ -154,9 +173,9 @@ def __init__(self, lib_dir,lmax, **kwargs):
self.lmax = lmax
super(_lib_phas, self).__init__(lib_dir, **kwargs)

def _build_sim_from_rng(self, idx, phas_only=False):
rng = np.random.RandomState(idx)
alm = (rng.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * rng.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
def _build_sim_from_rng(self, rng_state, phas_only=False):
np.random.set_state(rng_state)
alm = (np.random.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * np.random.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
if phas_only: return
m0 = hp.Alm.getidx(self.lmax, np.arange(self.lmax + 1,dtype = int),0)
alm[m0] = np.sqrt(2.) * alm[m0].real
Expand Down Expand Up @@ -188,3 +207,4 @@ def get_sim(self, idx, idf=None, phas_only=False):

def hashdict(self):
return {'nfields': self.nfields, 'lmax':self.lmax}

0 comments on commit e384dea

Please sign in to comment.