Skip to content

Commit

Permalink
added single-variable grid search to optimise transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
tkchafin committed May 9, 2024
1 parent 6fd504c commit b6a3551
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 134 deletions.
6 changes: 1 addition & 5 deletions scripts/ensembleResistnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def __init__(self):
["help", "out=", "in=", "network=", "reps=", "shp=",
"len_col=", "id_col=", "split_samples", "max_keep=",
"awsum=", "list=", "threads=", "edge_agg=", "varFile=",
"allShapes", "report_all", "noPlot", "only_best",
"report_all", "noPlot", "only_best",
"only_keep", "coords=", "seed="]
)

Expand Down Expand Up @@ -251,7 +251,6 @@ def set_default_values(self):
self.seed = None
self.GA_procs = 1
self.minimize = False
self.allShapes = False
self.report_all = False
self.plot = True
self.coords = None
Expand Down Expand Up @@ -329,8 +328,6 @@ def set_arguments(self, options):
)
elif opt in ("V", "varFile"):
self.varFile = arg
elif opt == "allShapes":
self.allShapes = True
else:
assert False, f"Unhandled option {opt!r}"

Expand Down Expand Up @@ -370,7 +367,6 @@ def display_help(self, message=None):
"-c, --id_col: Reach ID attribute (def=EDGE_ID)\n"
"-o, --out: Output file prefix (default=ensemble)\n"
"--report_all: Plot full outputs for all retained models\n"
"--allShapes: Allow inverse and reverse transformations\n"
"-V, --varfile: Optional file with variables provided like so:\n"
" var1 \t <Optional aggregator function>\n"
" var2 \t <Optional aggregator function>\n"
Expand Down
25 changes: 12 additions & 13 deletions scripts/runResistnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import os
import sys
import random
from datetime import datetime

Expand Down Expand Up @@ -66,17 +67,17 @@ def main():
verbose=True
)

# # Optionally optimise transformations for each parameter
# runner.optimise_transformations(
# fitmetric=params.fitmetric,
# threads=params.GA_procs,
# posWeight=params.posWeight,
# fixWeight=params.fixWeight,
# fixShape=params.fixShape,
# allShapes=params.allShapes,
# max_shape=params.max_shape,
# verbose=True
# )
# Optionally optimise transformations for each parameter
if params.gridSearch:
fixed_params = runner.optimise_univariate(
fitmetric=params.fitmetric,
threads=params.GA_procs,
max_shape=params.max_shape,
out=params.out,
plot=True,
verbose=True
)
# write fixed params to file

# Step 3: Run GA optimisation
runner.run_ga(
Expand All @@ -92,10 +93,8 @@ def main():
nFail=params.nfail,
popsize=params.popsize,
maxpopsize=params.maxpopsize,
posWeight=params.posWeight,
fixWeight=params.fixWeight,
fixShape=params.fixShape,
allShapes=params.allShapes,
min_weight=params.min_weight,
max_shape=params.max_shape,
max_hof_size=params.max_hof_size,
Expand Down
Loading

0 comments on commit b6a3551

Please sign in to comment.