Skip to content

Commit

Permalink
phas new numpy rng generator without setting/getting states manually
Browse files Browse the repository at this point in the history
  • Loading branch information
Saladino93 committed Sep 15, 2023
1 parent 456d8b9 commit 5c3b4d1
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions plancklens/sims/phas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,29 @@ def __init__(self, fname, idtype="INTEGER"):

self.con = sqlite3.connect(fname, timeout=3600., detect_types=sqlite3.PARSE_DECLTYPES)

def add(self, idx, state):
def add(self, idx):
idx = int(idx)
try:
assert (self.get(idx) is None)
keys_string = '_'.join(str(s) for s in state[1])
self.con.execute("INSERT INTO rngdb (id, type, pos, has_gauss, cached_gaussian, keys) VALUES (?,?,?,?,?,?)",
(idx, state[0], state[2], state[3], state[4], keys_string))
self.con.execute("INSERT INTO rngdb (id) VALUES (?)",
(idx))
self.con.commit()
except:
print("rng_db::rngdb add failed!")

def get(self, idx):
idx = int(idx)
cur = self.con.cursor()
cur.execute("SELECT type, pos, has_gauss, cached_gaussian, keys FROM rngdb WHERE id=?", (idx,))
cur.execute("SELECT id FROM rngdb WHERE id=?", (idx,))#probably won't be necessary anymore
data = cur.fetchone()
cur.close()
if data is None:
return None
else:
assert (len(data) == 5)
typ, pos, has_gauss, cached_gaussian, keys = data
assert (len(data) == 1)
id = data
keys = np.array([int(a) for a in keys.split('_')], dtype=np.uint32)
return [typ, keys, pos, has_gauss, cached_gaussian]
return [id]

def delete(self, idx):
idx = int(idx)
Expand All @@ -63,14 +62,11 @@ def delete(self, idx):


class sim_lib(object):
"""Generic class for simulations where only rng state is stored.
np.random rng states are stored in a sqlite3 database. By default the rng state function is np.random.get_state.
The rng_db class is tuned for this state fct, you may need to adapt this.
"""Generic class for simulations. We store the index idx, and then use np.random.RandomState
"""

def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
def __init__(self, lib_dir, nsims_max=None):
if not os.path.exists(lib_dir) and mpi.rank == 0:
os.makedirs(lib_dir)
self.nmax = nsims_max
Expand All @@ -83,13 +79,12 @@ def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
utils.hash_check(hsh, self.hashdict(), ignore=['lib_dir'], fn=fn_hash)

self._rng_db = rng_db(os.path.join(lib_dir, 'rngdb.db'), idtype='INTEGER')
self._get_rng_state = get_state_func

def get_sim(self, idx, **kwargs):
"""Returns sim number idx and caches random number generator state. """
if self.has_nmax(): assert idx < self.nmax
if not self.is_stored(idx):
self._rng_db.add(idx, self._get_rng_state())
self._rng_db.add(idx)
return self._build_sim_from_rng(self._rng_db.get(idx), **kwargs)

def has_nmax(self):
Expand Down Expand Up @@ -117,7 +112,7 @@ def hashdict(self):
"""Override this """
assert 0

def _build_sim_from_rng(self, rng_state):
def _build_sim_from_rng(self, idx):
"""Override this """
assert 0

Expand All @@ -127,9 +122,9 @@ def __init__(self, lib_dir, shape, **kwargs):
self.shape = shape
super(_pix_lib_phas, self).__init__(lib_dir, **kwargs)

def _build_sim_from_rng(self, rng_state, **kwargs):
np.random.set_state(rng_state)
return np.random.standard_normal(self.shape)
def _build_sim_from_rng(self, idx, **kwargs):
rng = np.random.RandomState(idx)
return rng.standard_normal(self.shape)

def hashdict(self):
return {'shape': self.shape}
Expand Down Expand Up @@ -159,9 +154,9 @@ def __init__(self, lib_dir,lmax, **kwargs):
self.lmax = lmax
super(_lib_phas, self).__init__(lib_dir, **kwargs)

def _build_sim_from_rng(self, rng_state, phas_only=False):
np.random.set_state(rng_state)
alm = (np.random.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * np.random.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
def _build_sim_from_rng(self, idx, phas_only=False):
rng = np.random.RandomState(idx)
alm = (rng.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * rng.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
if phas_only: return
m0 = hp.Alm.getidx(self.lmax, np.arange(self.lmax + 1,dtype = int),0)
alm[m0] = np.sqrt(2.) * alm[m0].real
Expand Down

0 comments on commit 5c3b4d1

Please sign in to comment.