Skip to content

Commit

Permalink
black + dead code
Browse files Browse the repository at this point in the history
  • Loading branch information
jcblemai committed Dec 7, 2023
1 parent 14366ea commit 9bcdf11
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 155 deletions.
5 changes: 3 additions & 2 deletions batch/inference_job_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,6 @@ def launch_batch(
continuation_run_id,
)


seir_modifiers_scenarios = None
outcome_modifiers_scenarios = None
# here the config is a dict
Expand Down Expand Up @@ -712,7 +711,9 @@ def launch(self, job_name, config_file, seir_modifiers_scenarios, outcome_modifi
cur_env_vars = base_env_vars.copy()
cur_env_vars.append({"name": "FLEPI_SEIR_SCENARIOS", "value": s})
cur_env_vars.append({"name": "FLEPI_OUTCOME_SCENARIOS", "value": d})
cur_env_vars.append({"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"}) # TODO: get it from gempyor and makes it contains run_id also in scripts
cur_env_vars.append(
{"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"}
) # TODO: get it from gempyor and makes it contains run_id also in scripts
cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": "1"})
cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"})
if not (self.restart_from_location is None):
Expand Down
7 changes: 3 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .compartments import compartments
from gempyor.utils import config


@click.group()
@click.option(
"-c",
Expand All @@ -17,11 +18,9 @@ def cli(config_file):
config.read(user=False)
config.set_file(config_file)

cli.add_command(compartments)

cli.add_command(compartments)


if __name__ == '__main__':
if __name__ == "__main__":
cli()


27 changes: 15 additions & 12 deletions flepimop/gempyor_pkg/src/gempyor/compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co

return rc

def toFile(self, compartments_file='compartments.parquet', transitions_file='transitions.parquet', write_parquet=True):
def toFile(
self, compartments_file="compartments.parquet", transitions_file="transitions.parquet", write_parquet=True
):
out_df = self.compartments.copy()
if write_parquet:
pa_df = pa.Table.from_pandas(out_df, preserve_index=False)
Expand Down Expand Up @@ -294,7 +296,7 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i
comp_idx = self.compartments[mask].index.values
if len(comp_idx) != 1:
raise ValueError(
f"The provided dictionary does not allow to isolate a compartment: {comp_dict} isolate {self.compartments[mask]} from options {self.compartments}. The get_comp_idx function was called by'{error_info}'."
f"The provided dictionary does not allow to isolate a compartment: {comp_dict} isolate {self.compartments[mask]} from options {self.compartments}. The get_comp_idx function was called by'{error_info}'."
)
return comp_idx[0]

Expand Down Expand Up @@ -493,7 +495,7 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names
# in this case we find the next array and set it to that size,
# TODO: instead of searching for the next array, better to just use the parameter shape.
if not isinstance(substituted_formulas[i], np.ndarray):
for k in range(len(substituted_formulas)):
for k in range(len(substituted_formulas)):
if isinstance(substituted_formulas[k], np.ndarray):
substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k])

Expand Down Expand Up @@ -650,40 +652,41 @@ def list_recursive_convert_to_string(thing):
return str(thing)



@click.group()
def compartments():
pass


# TODO: CLI arguments
@compartments.command()
def plot():
assert config["compartments"].exists()
assert config["seir"].exists()
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])

# TODO: this should be a command like build compartments.
(
unique_strings,
transition_array,
proportion_array,
proportion_info,
) = comp.get_transition_array()

comp.plot(output_file="transition_graph", source_filters=[], destination_filters=[])

print("wrote file transition_graph")


@compartments.command()
def export():
assert config["compartments"].exists()
assert config["seir"].exists()
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])
(
unique_strings,
transition_array,
proportion_array,
proportion_info,
) = comp.get_transition_array()
comp.toFile('compartments_file.csv', 'transitions_file.csv')
print("wrote files 'compartments_file.csv', 'transitions_file.csv' ")
comp.toFile("compartments_file.csv", "transitions_file.csv")
print("wrote files 'compartments_file.csv', 'transitions_file.csv' ")
18 changes: 0 additions & 18 deletions flepimop/gempyor_pkg/src/gempyor/outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,14 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
f"Places in seir input files does not correspond to subpops in outcome probability file {branching_file}"
)

# subclasses = [""]
# if modinf.outcomes_config["subclasses"].exists():
# subclasses = modinf.outcomes_config["subclasses"].get()

parameters = {}
for new_comp in outcomes_config:
if outcomes_config[new_comp]["source"].exists():
# for subclass in subclasses:
# class_name = new_comp + subclass
# parameters[class_name] = {}
parameters[new_comp] = {}
# Read the config for this compartement
src_name = outcomes_config[new_comp]["source"].get()
if isinstance(src_name, str):
# if src_name != "incidI":
# parameters[class_name]["source"] = src_name + subclass
# else:
# parameters[class_name]["source"] = src_name
parameters[new_comp]["source"] = src_name
# else:
# else:
# if subclasses != [""]:
# raise ValueError("Subclasses not compatible with outcomes from compartments ")
# elif ("incidence" in src_name.keys()) or ("prevalence" in src_name.keys()):
# parameters[class_name]["source"] = dict(src_name)

elif ("incidence" in src_name.keys()) or ("prevalence" in src_name.keys()):
parameters[new_comp]["source"] = dict(src_name)

Expand Down
65 changes: 35 additions & 30 deletions flepimop/gempyor_pkg/src/gempyor/seeding_ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _DataFrame2NumbaDict(df, amounts, setup) -> nb.typed.Dict:
n_seeding_ignored_before = 0
n_seeding_ignored_after = 0

#id_seed = 0
# id_seed = 0
for idx, (row_index, row) in enumerate(df.iterrows()):
if row["subpop"] not in setup.subpop_struct.subpop_names:
raise ValueError(
Expand All @@ -44,17 +44,21 @@ def _DataFrame2NumbaDict(df, amounts, setup) -> nb.typed.Dict:

if (row["date"].date() - setup.ti).days >= 0:
if (row["date"].date() - setup.ti).days < len(nb_seed_perday):

nb_seed_perday[(row["date"].date() - setup.ti).days] = (
nb_seed_perday[(row["date"].date() - setup.ti).days] + 1
)
source_dict = {grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names}
destination_dict = {grp_name: row[f"destination_{grp_name}"] for grp_name in cmp_grp_names}
seeding_dict["seeding_sources"][idx] = setup.compartments.get_comp_idx(source_dict, error_info = f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)")
seeding_dict["seeding_destinations"][idx] = setup.compartments.get_comp_idx(destination_dict, error_info = f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)")
seeding_dict["seeding_sources"][idx] = setup.compartments.get_comp_idx(
source_dict, error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)"
)
seeding_dict["seeding_destinations"][idx] = setup.compartments.get_comp_idx(
destination_dict,
error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)",
)
seeding_dict["seeding_subpops"][idx] = setup.subpop_struct.subpop_names.index(row["subpop"])
seeding_amounts[idx] = amounts[idx]
#id_seed+=1
# id_seed+=1
else:
n_seeding_ignored_after += 1
else:
Expand Down Expand Up @@ -303,14 +307,14 @@ def draw_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
raise NotImplementedError(f"unknown seeding method [got: {method}]")

# Sorting by date is very important here for the seeding format necessary !!!!
#print(seeding.shape)
# print(seeding.shape)
seeding = seeding.sort_values(by="date", axis="index").reset_index()
#print(seeding)
mask = (seeding['date'].dt.date > setup.ti) & (seeding['date'].dt.date <= setup.tf)
# print(seeding)
mask = (seeding["date"].dt.date > setup.ti) & (seeding["date"].dt.date <= setup.tf)
seeding = seeding.loc[mask].reset_index()
#print(seeding.shape)
#print(seeding)
# print(seeding.shape)
# print(seeding)

# TODO: print.

amounts = np.zeros(len(seeding))
Expand All @@ -322,11 +326,10 @@ def draw_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
elif method == "FolderDraw" or method == "FromFile":
amounts = seeding["amount"]


return _DataFrame2NumbaDict(df=seeding, amounts=amounts, setup=setup)

def load_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
""" only difference with draw seeding is that the sim_id is now sim_id2load"""
"""only difference with draw seeding is that the sim_id is now sim_id2load"""
return self.draw_seeding(sim_id=sim_id, setup=setup)

def load_ic(self, sim_id: int, setup) -> nb.typed.Dict:
Expand All @@ -336,19 +339,21 @@ def load_ic(self, sim_id: int, setup) -> nb.typed.Dict:
def seeding_write(self, seeding, fname, extension):
raise NotImplementedError(f"It is not yet possible to write the seeding to a file")


class SimulationComponent:
def __init__(self, config: confuse.ConfigView):
raise NotImplementedError("This method should be overridden in subclasses.")

def load(self, sim_id: int, setup) -> np.ndarray:
raise NotImplementedError("This method should be overridden in subclasses.")

def draw(self, sim_id: int, setup) -> np.ndarray:
raise NotImplementedError("This method should be overridden in subclasses.")

def write_to_file(self, sim_id: int, setup):
raise NotImplementedError("This method should be overridden in subclasses.")



class Seeding(SimulationComponent):
def __init__(self, config: confuse.ConfigView):
self.seeding_config = config
Expand Down Expand Up @@ -393,14 +398,14 @@ def draw(self, sim_id: int, setup) -> nb.typed.Dict:
raise NotImplementedError(f"unknown seeding method [got: {method}]")

# Sorting by date is very important here for the seeding format necessary !!!!
#print(seeding.shape)
# print(seeding.shape)
seeding = seeding.sort_values(by="date", axis="index").reset_index()
#print(seeding)
mask = (seeding['date'].dt.date > setup.ti) & (seeding['date'].dt.date <= setup.tf)
# print(seeding)
mask = (seeding["date"].dt.date > setup.ti) & (seeding["date"].dt.date <= setup.tf)
seeding = seeding.loc[mask].reset_index()
#print(seeding.shape)
#print(seeding)
# print(seeding.shape)
# print(seeding)

# TODO: print.

amounts = np.zeros(len(seeding))
Expand All @@ -412,17 +417,17 @@ def draw(self, sim_id: int, setup) -> nb.typed.Dict:
elif method == "FolderDraw" or method == "FromFile":
amounts = seeding["amount"]


return _DataFrame2NumbaDict(df=seeding, amounts=amounts, setup=setup)

def load(self, sim_id: int, setup) -> nb.typed.Dict:
""" only difference with draw seeding is that the sim_id is now sim_id2load"""
"""only difference with draw seeding is that the sim_id is now sim_id2load"""
return self.draw(sim_id=sim_id, setup=setup)



class InitialConditions(SimulationComponent):
def __init__(self, config: confuse.ConfigView):
self.initial_conditions_config = config

def draw(self, sim_id: int, setup) -> np.ndarray:
method = "Default"
if self.initial_conditions_config is not None and "method" in self.initial_conditions_config.keys():
Expand Down Expand Up @@ -600,6 +605,6 @@ def draw(self, sim_id: int, setup) -> np.ndarray:
""" Ignoring the previous population mismatch errors because you added flag 'ignore_population_checks'. This is dangerous"""
)
return y0

def load(self, sim_id: int, setup) -> nb.typed.Dict:
return self.draw(sim_id=sim_id, setup=setup)
return self.draw(sim_id=sim_id, setup=setup)
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@
help="write parquet file output at end of simulation",
)
# @profile_options
#@profile()
# @profile()
def simulate(
config_file,
in_run_id,
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ def test_outcomes_read_write_hnpi2_custom_pname():
first_sim_index=1,
outcome_modifiers_scenario="Some",
stoch_traj_flag=False,
out_run_id=107,
)
out_run_id=107,
)

outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1)

Expand Down
23 changes: 12 additions & 11 deletions flepimop/gempyor_pkg/tests/seir/dev_new_test0.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def test_parameters_from_timeserie_file():
# if True:
# if True:
config.clear()
config.read(user=False)
config.set_file(f"{DATA_DIR}/config_compartmental_model_format_with_covariates.yml")
Expand All @@ -32,19 +32,20 @@ def test_parameters_from_timeserie_file():
outcome_scenario="high_death_rate",
stoch_traj_flag=False,
)

p = parameters.Parameters(
parameter_config=config["seir"]["parameters"],
ti=config["start_date"].as_date(),
tf=config["end_date"].as_date(),
nodenames=inference_simulator.s.spatset.nodenames,
config_version="v3")

#p = inference_simulator.s.parameters
parameter_config=config["seir"]["parameters"],
ti=config["start_date"].as_date(),
tf=config["end_date"].as_date(),
nodenames=inference_simulator.s.spatset.nodenames,
config_version="v3",
)

# p = inference_simulator.s.parameters
p_draw = p.parameters_quick_draw(n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes)

p_df = p.getParameterDF(p_draw)["parameter"]

for pn in p.pnames:
if pn == "R0s":
assert pn not in p_df
Expand All @@ -54,7 +55,7 @@ def test_parameters_from_timeserie_file():
initial_df = read_df("data/r0s_ts.csv").set_index("date")

assert (p_draw[p.pnames2pindex["R0s"]] == initial_df.values).all()

### test what happen when the order of geoids is not respected (expected: reput them in order)

### test what happens with incomplete data (expected: fail)
Expand Down
Loading

0 comments on commit 9bcdf11

Please sign in to comment.