Skip to content

Commit

Permalink
Updated opt
Browse files Browse the repository at this point in the history
  • Loading branch information
Hjorthmedh committed Sep 7, 2023
1 parent ecaab37 commit 0a5b12a
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 8,512 deletions.
115 changes: 62 additions & 53 deletions examples/notebooks/optimise_prune/OptimisePrune.ipynb

Large diffs are not rendered by default.

8,389 changes: 170 additions & 8,219 deletions examples/notebooks/optimise_prune/run_all_fs_spn.ipynb

Large diffs are not rendered by default.

286 changes: 47 additions & 239 deletions snudda/optimise/optimise_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def prune_synapses(self, pre_type, post_type, con_type, pruning_parameters, outp
self.prune.connectivity_distributions = dict([])
self.prune.connectivity_distributions[pre_type_id, post_type_id, synapse_type_id] = (pruning, pruning_other)

# Update the config file in self.prune so the data gets written to network_synapses

assert len(self.merge_files_syn) == 1, f"merge_Files_syn should be a list with one file only"

# print(f"Writing to {output_file} (*)")
Expand Down Expand Up @@ -239,165 +241,22 @@ def evaluate_fitness(self, pre_type, post_type, output_file, experimental_data,
return error

@staticmethod
def helper_func1(x, *args):
    """Fitness wrapper for optimising a single pruning parameter (f1).

    Intended as the objective function for scipy differential_evolution.

    Args:
        x: Parameter vector; x[0] is the f1 pruning parameter.
        *args: Must contain exactly one element, the optimisation_info dict
            (pre_type, post_type, con_type, exp_data, network_path, ...).

    Returns:
        The fitness (error) value for this parameter choice (lower is better).

    Raises:
        ValueError: If no optimisation_info was passed via args.
    """

    # *args is always a tuple (never None); an empty tuple means the caller
    # forgot to supply optimisation_info.
    if not args:
        raise ValueError("No optimisation_info passed.")

    optimisation_info = args[0]

    pruning_parameters = dict()

    if "extra_pruning_parameters" in optimisation_info:
        # Fixed (non-optimised) pruning parameters supplied by the user
        pruning_parameters |= optimisation_info["extra_pruning_parameters"]

    pruning_parameters["f1"] = x[0]

    if "output_file" in optimisation_info:
        output_file = optimisation_info["output_file"]
    else:
        # Unique temp file so parallel workers do not clobber each other
        output_file = os.path.join(optimisation_info["network_path"], "temp", f"network-synapses-{uuid.uuid4()}.hdf5")
        # print(f"Output file {output_file}")

    # This trick allows us to reuse the same OptimisePruning object, will be faster
    op = OptimisePruning.get_op(optimisation_info)

    op.prune_synapses(pre_type=optimisation_info["pre_type"],
                      post_type=optimisation_info["post_type"],
                      con_type=optimisation_info["con_type"],
                      pruning_parameters=pruning_parameters,
                      output_file=output_file)

    fitness = op.evaluate_fitness(pre_type=optimisation_info["pre_type"],
                                  post_type=optimisation_info["post_type"],
                                  output_file=output_file,
                                  experimental_data=optimisation_info["exp_data"],
                                  avg_num_synapses_per_pair=optimisation_info["avg_num_synapses_per_pair"])

    # print(f"Evaluating f1 = {x[0]}, fitness: {fitness}\n{output_file}\n")
    # print(f"Fitness: {fitness}")

    OptimisePruning.report_fitness(fitness)

    # Only remove the file if it was an auto-generated temp file
    if "output_file" not in optimisation_info:
        os.remove(output_file)

    return fitness

@staticmethod
def helper_func2(x, *args):
    """Fitness wrapper for optimising two pruning parameters (f1, mu2).

    Intended as the objective function for scipy differential_evolution.

    Args:
        x: Parameter vector; x[0] is f1, x[1] is mu2.
        *args: Must contain exactly one element, the optimisation_info dict
            (pre_type, post_type, con_type, exp_data, network_path, ...).

    Returns:
        The fitness (error) value for this parameter choice (lower is better).

    Raises:
        ValueError: If no optimisation_info was passed via args.
    """

    # *args is always a tuple (never None); an empty tuple means the caller
    # forgot to supply optimisation_info.
    if not args:
        raise ValueError("No optimisation_info passed.")

    optimisation_info = args[0]

    pruning_parameters = dict()

    if "extra_pruning_parameters" in optimisation_info:
        # Fixed (non-optimised) pruning parameters supplied by the user
        pruning_parameters |= optimisation_info["extra_pruning_parameters"]

    pruning_parameters["f1"] = x[0]
    pruning_parameters["mu2"] = x[1]

    if "output_file" in optimisation_info:
        output_file = optimisation_info["output_file"]
    else:
        # Unique temp file so parallel workers do not clobber each other
        output_file = os.path.join(optimisation_info["network_path"], "temp", f"network-synapses-{uuid.uuid4()}.hdf5")
        # print(f"Output file {output_file}")

    # This trick allows us to reuse the same OptimisePruning object, will be faster
    op = OptimisePruning.get_op(optimisation_info)

    op.prune_synapses(pre_type=optimisation_info["pre_type"],
                      post_type=optimisation_info["post_type"],
                      con_type=optimisation_info["con_type"],
                      pruning_parameters=pruning_parameters,
                      output_file=output_file)

    fitness = op.evaluate_fitness(pre_type=optimisation_info["pre_type"],
                                  post_type=optimisation_info["post_type"],
                                  output_file=output_file,
                                  experimental_data=optimisation_info["exp_data"],
                                  avg_num_synapses_per_pair=optimisation_info["avg_num_synapses_per_pair"])

    # print(f"Evaluating f1 = {x[0]}, mu2 = {x[1]}, fitness: {fitness}\n{output_file}\n")
    # print(f"Fitness: {fitness}")

    OptimisePruning.report_fitness(fitness)

    # Only remove the file if it was an auto-generated temp file
    if "output_file" not in optimisation_info:
        os.remove(output_file)

    return fitness

@staticmethod
def helper_func3(x, *args):
def opt_helper_func(x, *args):

pruning_parameters = dict()

if args is not None:
optimisation_info = args[0]

if "extra_pruning_parameters" in optimisation_info:
pruning_parameters |= optimisation_info["extra_pruning_parameters"]
else:
raise ValueError(f"No optimisation_info passed.")

pruning_parameters["f1"] = x[0]
pruning_parameters["mu2"] = x[1]
pruning_parameters["a3"] = x[2]

if "output_file" in optimisation_info:
output_file = optimisation_info["output_file"]
# args must be (optimisation_info)
if args is None:
raise ValueError("args must contain optimisation_info")
else:
output_file = os.path.join(optimisation_info["network_path"], "temp", f"network-synapses-{uuid.uuid4()}.hdf5")
# print(f"Output file {output_file}")

# This trick allows us to reuse the same OptimisePruning object, will be faster
op = OptimisePruning.get_op(optimisation_info)

op.prune_synapses(pre_type=optimisation_info["pre_type"],
post_type=optimisation_info["post_type"],
con_type=optimisation_info["con_type"],
pruning_parameters=pruning_parameters,
output_file=output_file)

fitness = op.evaluate_fitness(pre_type=optimisation_info["pre_type"],
post_type=optimisation_info["post_type"],
output_file=output_file,
experimental_data=optimisation_info["exp_data"],
avg_num_synapses_per_pair=optimisation_info["avg_num_synapses_per_pair"])

# print(f"Evaluating f1 = {x[0]}, mu2 = {x[1]}, a3 = {x[2]}, fitness: {fitness}\n{output_file}\n")
# print(f"Fitness: {fitness}")

OptimisePruning.report_fitness(fitness)

if "output_file" not in optimisation_info:
os.remove(output_file)

return fitness

@staticmethod
def helper_func4(x, *args):

# Includes softmax

pruning_parameters = dict()

if args is not None:
optimisation_info = args[0]
param_names = optimisation_info["param_names"]

if "extra_pruning_parameters" in optimisation_info:
pruning_parameters |= optimisation_info["extra_pruning_parameters"]
else:
raise ValueError(f"No optimisation_info passed.")

pruning_parameters["f1"] = x[0]
pruning_parameters["softMax"] = x[1]
pruning_parameters["mu2"] = x[2]
pruning_parameters["a3"] = x[3]
for p_name, p_value in zip(param_names, x):
pruning_parameters[p_name] = p_value

if "output_file" in optimisation_info:
output_file = optimisation_info["output_file"]
Expand All @@ -420,7 +279,7 @@ def helper_func4(x, *args):
experimental_data=optimisation_info["exp_data"],
avg_num_synapses_per_pair=optimisation_info["avg_num_synapses_per_pair"])

# print(f"Evaluating f1 = {x[0]}, SM = {x[1]}, mu2 = {x[2]}, a3 = {x[3]}, fitness: {fitness}\n{output_file}\n")
# print(f"Evaluating f1 = {x[0]}, fitness: {fitness}\n{output_file}\n")
# print(f"Fitness: {fitness}")

OptimisePruning.report_fitness(fitness)
Expand Down Expand Up @@ -460,8 +319,9 @@ def get_op(optimisation_info):

def optimize(self, pre_type, post_type, con_type,
experimental_data,
param_names, param_bounds,
extra_pruning_parameters, avg_num_synapses_per_pair=None,
workers=1, maxiter=50, tol=0.001, pop_size=None, num_params=4):
workers=1, maxiter=50, tol=0.001, pop_size=None):

start = timeit.default_timer()

Expand All @@ -470,6 +330,20 @@ def optimize(self, pre_type, post_type, con_type,
if pop_size is None:
pop_size = self.pop_size

if param_bounds == "default":
param_bounds = []
for p_name in param_names:
if p_name == "f1":
param_bounds.append((0, 1))
elif p_name == "softMax":
param_bounds.append((0, 20))
elif p_name == "mu2":
param_bounds.append((0, 5))
elif p_name == "a3":
param_bounds.append((0, 1))
else:
raise ValueError(f"No default parameter bounds for {p_name} (f1, softMax, mu2, a3)")

self.optimisation_info["pre_type"] = pre_type
self.optimisation_info["post_type"] = post_type
self.optimisation_info["con_type"] = con_type
Expand All @@ -478,64 +352,18 @@ def optimize(self, pre_type, post_type, con_type,
self.optimisation_info["extra_pruning_parameters"] = extra_pruning_parameters
self.optimisation_info["ctr"] = 0
self.optimisation_info["network_path"] = self.network_path
self.optimisation_info["param_names"] = param_names
self.optimisation_info["param_bounds"] = param_bounds

optimisation_info = self.optimisation_info

if num_params == 4:
# With softmax
bounds4 = [(0, 1), (0, 20), (0, 5), (0, 1)]
res = differential_evolution(func=OptimisePruning.helper_func4, args=(optimisation_info, ),
bounds=bounds4, workers=workers, maxiter=maxiter, tol=tol,
popsize=pop_size)

# Rerun the best parameters, and keep data as network-synapses.hdf5
optimisation_info["output_file"] = os.path.join(self.network_path, "network-synapses.hdf5")
OptimisePruning.helper_func4(res.x, optimisation_info)

elif num_params == 3:
# Without softmax
bounds3 = [(0, 1), (0, 5), (0, 1)]
res = differential_evolution(func=OptimisePruning.helper_func3, args=(optimisation_info, ),
bounds=bounds3, workers=workers, maxiter=maxiter, tol=tol,
popsize=self.pop_size)

optimisation_info["output_file"] = os.path.join(self.network_path, "network-synapses.hdf5")
OptimisePruning.helper_func3(res.x, optimisation_info)

elif num_params == 2:

# Without softmax
bounds3 = [(0, 1), (0, 5)]
res = differential_evolution(func=OptimisePruning.helper_func2, args=(optimisation_info, ),
bounds=bounds3, workers=workers, maxiter=maxiter, tol=tol,
popsize=self.pop_size)

optimisation_info["output_file"] = os.path.join(self.network_path, "network-synapses.hdf5")
OptimisePruning.helper_func2(res.x, optimisation_info)

elif num_params == 1:

# try:
# OptimisePruning.helper_func1([0.5], optimisation_info)
# except:
# import traceback
# print(traceback.format_exc())
# import pdb
# pdb.set_trace()

# Without softmax
bounds3 = [(0, 1)]
res = differential_evolution(func=OptimisePruning.helper_func1, args=(optimisation_info, ),
bounds=bounds3, workers=workers, maxiter=maxiter, tol=tol,
popsize=self.pop_size)

optimisation_info["output_file"] = os.path.join(self.network_path, "network-synapses.hdf5")
OptimisePruning.helper_func1(res.x, optimisation_info)
res = differential_evolution(func=OptimisePruning.opt_helper_func, args=(optimisation_info,),
bounds=optimisation_info["param_bounds"], workers=workers,
maxiter=maxiter, tol=tol, popsize=pop_size)

else:
raise ValueError(f"num_params = {num_params} must be 2,3, or 4")

# res = differential_evolution(func=self.helper_func, bounds=bounds, workers=workers)
# Rerun the best parameters, and keep data as network-synapses.hdf5
optimisation_info["output_file"] = os.path.join(self.network_path, "network-synapses.hdf5")
OptimisePruning.opt_helper_func(res.x, optimisation_info)

duration = timeit.default_timer() - start
self.log_file.write(f"Duration: {duration} s\n")
Expand All @@ -545,43 +373,28 @@ def optimize(self, pre_type, post_type, con_type,

return res

@staticmethod
def get_parameters(res, optimisation_info):
    """Convert an optimisation result into a pruning-parameter dict.

    Args:
        res: Optimisation result object (e.g. from scipy differential_evolution);
            res.x holds the optimised values, in the same order as
            optimisation_info["param_names"].
        optimisation_info: dict containing "param_names" and optionally
            "extra_pruning_parameters" (fixed parameters, merged in first).

    Returns:
        dict mapping pruning parameter names to their values; optimised
        values override any identically-named extra pruning parameters.
    """

    param_names = optimisation_info["param_names"]

    pruning_parameters = dict()

    if "extra_pruning_parameters" in optimisation_info:
        # Fixed (non-optimised) pruning parameters supplied by the user
        pruning_parameters |= optimisation_info["extra_pruning_parameters"]

    # Pair each optimised value with its parameter name, in optimiser order
    for p_name, p_value in zip(param_names, res.x):
        pruning_parameters[p_name] = p_value

    return pruning_parameters

def export_json(self, file_name, res, append=False):

pre_type = self.optimisation_info["pre_type"]
post_type = self.optimisation_info["post_type"]
connection_type = self.optimisation_info["con_type"]

param_names = self.optimisation_info["param_names"]

if append and os.path.isfile(file_name):
with open(file_name, "r") as f:
config_data = json.load(f)
Expand All @@ -598,16 +411,11 @@ def export_json(self, file_name, res, append=False):
if connection_type not in con_data:
con_data[connection_type] = dict()

if "pruning" not in con_data[connection_type]:
con_data[connection_type]["pruning"] = dict()

con_data[connection_type]["pruning"] = self.get_parameters(res)
con_data[connection_type]["pruning"] = self.get_parameters(res=res, optimisation_info=self.optimisation_info)
if "pruningOther" in con_data[connection_type]:
del con_data[connection_type]["pruningOther"]

config_data["Connectivity"][f"{pre_type},{post_type}"] = con_data

with open(file_name, "w") as f:
json.dump(config_data, f, cls=NumpyEncoder, indent=4)

# WORKING ON WRITING CONNECTION DATA TO JSON FILE AUTOMATICALLY, for easier import into SNUDDA
json.dump(config_data, f, cls=NumpyEncoder, indent=4)
Loading

0 comments on commit 0a5b12a

Please sign in to comment.