diff --git a/config/config.json b/config/config.json
new file mode 100644
index 0000000..0e9424d
--- /dev/null
+++ b/config/config.json
@@ -0,0 +1,8 @@
+{
+    "model_type": "MTLSD",
+    "iterations": 100000,
+    "warmup": 100000,
+    "raw_file": "path/to/zarr/or/n5",
+    "voxel_size": 33,
+    "python_script_path": "path/to/python_script.py"
+}
diff --git a/optoseg/README.md b/optoseg/README.md
deleted file mode 100644
index e45b007..0000000
--- a/optoseg/README.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# optoseg
-Parallelizing training for hyper-parameter tuning and batch jobs.
diff --git a/optoseg/go.mod b/optoseg/go.mod
deleted file mode 100644
index d3e9874..0000000
--- a/optoseg/go.mod
+++ /dev/null
@@ -1,3 +0,0 @@
-module optoseg
-
-go 1.19
diff --git a/optoseg/src/jobs/job.go b/optoseg/src/jobs/job.go
deleted file mode 100644
index fad5338..0000000
--- a/optoseg/src/jobs/job.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package jobs
-
-import (
-	"fmt"
-	"os/exec"
-	"sync"
-)
-
-func TrainModel(modelType string, scriptPath string, args ...string) {
-	cmd := exec.Command(scriptPath, args...)
-	output, err := cmd.CombinedOutput()
-	if err != nil {
-		fmt.Printf("Error running %s: %v\nOutput:\n%s\n", scriptPath, err, output)
-		return
-	}
-	fmt.Printf("Output from %s:\n%s\n", scriptPath, output)
-}
-
-func RunTrainingInParallel(trainingParams []struct {
-	modelType  string
-	scriptPath string
-	args       []string
-}) {
-	var wg sync.WaitGroup
-
-	for _, params := range trainingParams {
-		wg.Add(1)
-		go func(p struct {
-			modelType  string
-			scriptPath string
-			args       []string
-		}) {
-			defer wg.Done()
-			TrainModel(p.modelType, p.scriptPath, p.args...)
-		}(params)
-	}
-
-	wg.Wait()
-}
diff --git a/optoseg/src/main.go b/optoseg/src/main.go
deleted file mode 100644
index 9b63749..0000000
--- a/optoseg/src/main.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package main
-
-import "autoseg"
-
-func main() {
-	trainingParams := []struct {
-		modelType  string
-		scriptPath string
-		args       []string
-	}{ // TODO: add model type arg
-		{"MTLSD", "../batch_run.py", []string{"--iterations", "100000", "--raw_file", "path/to/zarr/or/n5", "--voxel_size", "33"}},
-		{"ACLSD", "../batch_run.py", []string{"--iterations", "100000", "--raw_file", "path/to/zarr/or/n5", "--warmup", "100000"}},
-		{"STELARR", "../batch_run.py", []string{"--iterations", "100000", "--raw_file", "path/to/zarr/or/n5", "--warmup", "100000"}},
-	}
-
-	autoseg.RunTrainingInParallel(trainingParams)
-}
\ No newline at end of file
diff --git a/segmonitor/Cargo.toml b/segmonitor/Cargo.toml
new file mode 100644
index 0000000..01e8511
--- /dev/null
+++ b/segmonitor/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "segmonitor"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+# The sync API is required because lib.rs uses mongodb::sync::Client.
+mongodb = { version = "2.0", default-features = false, features = ["sync"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
diff --git a/segmonitor/src/lib.rs b/segmonitor/src/lib.rs
new file mode 100644
index 0000000..2506108
--- /dev/null
+++ b/segmonitor/src/lib.rs
@@ -0,0 +1,55 @@
+use std::process::Command;
+
+use mongodb::sync::Client;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Deserialize, Serialize)]
+struct Config {
+    model_type: String,
+    iterations: i32,
+    warmup: i32,
+    raw_file: String,
+    voxel_size: i32,
+    python_script_path: String,
+}
+
+// Top-level functions (rather than a nested `mod segmonitor`) so that
+// main.rs can call `segmonitor::train_model_from_config` directly.
+pub fn train_model_from_config(config_path: &str) {
+    let config = load_config(config_path);
+
+    println!("Training model: {}", config.model_type);
println!("Iterations: {}", config.iterations); + println!("Warmup: {}", config.warmup); + println!("Raw file: {}", config.raw_file); + println!("Voxel size: {}", config.voxel_size); + + call_python_train(&config.python_script_path); + + save_to_mongodb(&config); + } + + fn load_config(config_path: &str) -> Config { + // Load configuration from JSON file + let config_str = std::fs::read_to_string(config_path).expect("Error reading config file"); + serde_json::from_str(&config_str).expect("Error parsing JSON") + } + + fn call_python_train(script_path: &str) { + let output = Command::new("python") + .arg(script_path) + .output() + .expect("Failed to execute training"); + + if output.status.success() { + println!("Traning executed successfully!"); + } else { + println!("Error executing training:\n{}", String::from_utf8_lossy(&output.stderr)); + } + } + + fn save_to_mongodb(config: &Config) { + println!("Saving metrics to MongoDB..."); + let client_options = ClientOptions::parse("mongodb://localhost:27017").unwrap(); + let client = Client::with_options(client_options).unwrap(); + // TODO: dd MongoDB insertion logic here + } +} diff --git a/segmonitor/src/main.rs b/segmonitor/src/main.rs new file mode 100644 index 0000000..381149c --- /dev/null +++ b/segmonitor/src/main.rs @@ -0,0 +1,6 @@ +use segmonitor; + +fn main() { + let config_path = "path/to/config.json"; + segmonitor::train_model_from_config(config_path); +} diff --git a/src/autoseg/eval/eval_db.py b/src/autoseg/eval/eval_db.py index e6376d8..a5d7b58 100644 --- a/src/autoseg/eval/eval_db.py +++ b/src/autoseg/eval/eval_db.py @@ -1,6 +1,7 @@ import sqlite3 import json + class Database: """ Simple SQLite Database Wrapper for Storing and Retrieving Scores. @@ -9,17 +10,17 @@ class Database: Each score entry is associated with a network, checkpoint, threshold, and a dictionary of scores. Args: - db_name (str): + db_name (str): The name of the SQLite database file. - table_name (str): + table_name (str): The name of the table within the database (default is 'scores_table'). Attributes: - conn (sqlite3.Connection): + conn (sqlite3.Connection): The SQLite database connection. - cursor (sqlite3.Cursor): + cursor (sqlite3.Cursor): The SQLite database cursor. - table_name (str): + table_name (str): The name of the table within the database. Methods: @@ -50,13 +51,13 @@ def add_score(self, network, checkpoint, threshold, scores_dict): Add a score entry to the database. Args: - network (str): + network (str): The name of the network. - checkpoint (int): + checkpoint (int): The checkpoint number. - threshold (float): + threshold (float): The threshold value. - scores_dict (dict): + scores_dict (dict): A dictionary containing scores. """ assert type(network) is str @@ -74,11 +75,11 @@ def get_scores(self, networks=None, checkpoints=None, thresholds=None): Retrieve scores from the database based on specified conditions. Args: - networks (str, list): + networks (str, list): The name or list of names of networks to filter on. - checkpoints (int, list): + checkpoints (int, list): The checkpoint number or list of checkpoint numbers to filter on. - thresholds (float, list): + thresholds (float, list): The threshold value or list of threshold values to filter on. Returns: diff --git a/src/autoseg/eval/evaluate.py b/src/autoseg/eval/evaluate.py index 7cfd552..304865c 100644 --- a/src/autoseg/eval/evaluate.py +++ b/src/autoseg/eval/evaluate.py @@ -23,15 +23,15 @@ def segment_and_validate( It logs information about the segmentation and validation process. 
 
     Args:
-        model_checkpoint (str): 
+        model_checkpoint (str):
            The checkpoint of the segmentation model to use (default is "latest").
-        checkpoint_num (int): 
+        checkpoint_num (int):
            The checkpoint number for the affinity model (default is 250000).
-        setup_num (str): 
+        setup_num (str):
            The setup number for the affinity model (default is "1738").
 
     Returns:
-        dict: 
+        dict:
            A dictionary containing scores and evaluation metrics.
     """
     logger.info(
@@ -85,27 +85,27 @@
     Validate segmentation results using specified parameters.
 
     Args:
-        checkpoint (str): 
+        checkpoint (str):
            The checkpoint identifier.
-        threshold (float): 
+        threshold (float):
            The threshold value.
-        offset (str): 
+        offset (str):
            The offset for ROI (default is "3960,3960,3960").
-        roi_shape (str): 
+        roi_shape (str):
            The shape of ROI (default is "31680,31680,31680").
-        skel (str): 
+        skel (str):
            The path to the skeleton data file (default is "../../data/XPRESS_validation_skels.npz").
-        zarr (str): 
+        zarr (str):
            The path to the Zarr file for storing segmentation data (default is "./validation.zarr").
-        h5 (str): 
+        h5 (str):
            The path to the HDF5 file for storing validation data (default is "validation.h5").
-        ds (str): 
+        ds (str):
            The dataset name (default is "pred_seg").
-        print_errors (bool): 
+        print_errors (bool):
            Print errors during validation (default is False).
-        print_in_xyz (bool): 
+        print_in_xyz (bool):
            Print coordinates in XYZ format (default is False).
-        downsample (int): 
+        downsample (int):
            Downsample factor for evaluation (default is None).
     """
     network = os.path.abspath(".").split(os.path.sep)[-1]
diff --git a/src/autoseg/eval/metrics.py b/src/autoseg/eval/metrics.py
index c34c914..aedc71f 100644
--- a/src/autoseg/eval/metrics.py
+++ b/src/autoseg/eval/metrics.py
@@ -18,13 +18,13 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path):
     the ROI but originally belonged to the same skeleton ID.
 
     Args:
-        segment_array (daisy.Array): 
+        segment_array (daisy.Array):
            Array containing predicted segmentation labels.
-        skeleton_path (str): 
+        skeleton_path (str):
            Path to the skeleton data file.
 
     Returns:
-        networkx.Graph: 
+        networkx.Graph:
            Ground-truth graph with predicted labels added.
     """
     gt_graph = np.load(skeleton_path, allow_pickle=True)
@@ -66,13 +66,13 @@
     Compute Expected Run Length (ERL) and normalized ERL for a given graph.
 
     Args:
-        graph (networkx.Graph): 
+        graph (networkx.Graph):
            Graph representing the ground-truth.
 
     Returns:
-        float: 
+        float:
            ERL value.
-        float: 
+        float:
            Normalized ERL value.
     """
     node_seg_lut = {}
@@ -106,13 +106,13 @@
     Build a subgraph using a set of segment nodes from the given graph.
 
     Args:
-        segment_nodes (Iterable): 
+        segment_nodes (Iterable):
            Nodes representing segments.
-        graph (networkx.Graph): 
+        graph (networkx.Graph):
            Original graph.
 
     Returns:
-        networkx.Graph: 
+        networkx.Graph:
            Subgraph containing specified segment nodes.
     """
     subgraph = graph.subgraph(segment_nodes)
@@ -137,15 +137,15 @@
     Get the closest pair of nodes between two skeletons in the given graph.
 
     Args:
-        skel1 (Iterable): 
+        skel1 (Iterable):
            Nodes of the first skeleton.
-        skel2 (Iterable): 
+        skel2 (Iterable):
            Nodes of the second skeleton.
-        graph (networkx.Graph): 
+        graph (networkx.Graph):
            Original graph.
 
     Returns:
-        Tuple: 
+        Tuple:
            Closest pair of nodes and their edge attributes.
""" multiplier = (1, 1, 1) @@ -168,11 +168,11 @@ def find_merge_errors(graph): Find merge errors in the given graph. Args: - graph (networkx.Graph): + graph (networkx.Graph): Original graph. Returns: - set: + set: Set of merge errors. """ seg_dict = {} @@ -210,11 +210,11 @@ def get_split_merges(graph): Find split merges in the given graph. Args: - graph (networkx.Graph): + graph (networkx.Graph): Original graph. Returns: - set: + set: Set of split merges. """ # Count split errors. An error is an edge in the GT skeleton graph connecting two nodes @@ -232,11 +232,11 @@ def set_point_in_array(array, point_coord, val): Set a specific point in the array to a given value. Args: - array (daisy.Array): + array (daisy.Array): Target array. - point_coord (Tuple): + point_coord (Tuple): Coordinates of the point. - val: + val: Value to set. """ point_coord = daisy.Coordinate(point_coord) @@ -250,13 +250,13 @@ def make_voxel_gt_array(test_array, gt_graph): Rasterize ground-truth points to an empty array for computing Rand/VOI. Args: - test_array (daisy.Array): + test_array (daisy.Array): Target array. - gt_graph (networkx.Graph): + gt_graph (networkx.Graph): Ground-truth graph. Returns: - daisy.Array: + daisy.Array: Voxel array containing ground-truth information. """ gt_ndarray = np.zeros_like(test_array.data).astype(np.uint64) @@ -275,13 +275,13 @@ def get_voi(segment_array, gt_graph): Wrapper function to compute Rand/VOI scores. Args: - segment_array (daisy.Array): + segment_array (daisy.Array): Array containing predicted segmentation labels. - gt_graph (networkx.Graph): + gt_graph (networkx.Graph): Ground-truth graph. Returns: - Dict: + Dict: Dictionary containing Rand/VOI scores. """ voxel_gt = make_voxel_gt_array(segment_array, gt_graph) @@ -296,15 +296,15 @@ def run_eval(skeleton_file, segmentation_file, segmentation_ds, roi, downsamplin Run evaluation on the predicted segmentation. Args: - skeleton_file (str): + skeleton_file (str): Path to the skeleton data file. - segmentation_file (str): + segmentation_file (str): Path to the segmentation data file. - segmentation_ds (str): + segmentation_ds (str): Dataset name in the segmentation file. - roi (daisy.Roi): + roi (daisy.Roi): Region of interest. - downsampling (int): + downsampling (int): Downsample factor for evaluation. Returns: diff --git a/src/autoseg/train/ACLSDTrain.py b/src/autoseg/train/ACLSDTrain.py index 1f374ce..1a9538c 100644 --- a/src/autoseg/train/ACLSDTrain.py +++ b/src/autoseg/train/ACLSDTrain.py @@ -33,17 +33,17 @@ def aclsd_train( Train ACLSD model using Gunpowder library. Args: - raw_file (str): + raw_file (str): Path to the raw data file. - out_file (str): + out_file (str): Output path for saving predictions. - voxel_size (int): + voxel_size (int): Voxel size. - iterations (int): + iterations (int): Number of training iterations. - warmup (int): + warmup (int): Number of warm-up iterations. - save_every (int): + save_every (int): Save predictions every 'save_every' iterations. """ raw = gp.ArrayKey("RAW") diff --git a/src/autoseg/train/MTLSDTrain.py b/src/autoseg/train/MTLSDTrain.py index 70df5d2..d83a6dd 100644 --- a/src/autoseg/train/MTLSDTrain.py +++ b/src/autoseg/train/MTLSDTrain.py @@ -28,13 +28,13 @@ def mtlsd_train( Train MTLSD model using Gunpowder library. Args: - raw_file (str): + raw_file (str): Path to the raw data file. - voxel_size (int): + voxel_size (int): Voxel size. - iterations (int): + iterations (int): Number of training iterations. 
-        save_every (int): 
+        save_every (int):
            Save predictions every 'save_every' iterations.
     """
     raw = gp.ArrayKey("RAW")
diff --git a/src/autoseg/train/STELARRTrain.py b/src/autoseg/train/STELARRTrain.py
index 0c4ae75..371f1a3 100644
--- a/src/autoseg/train/STELARRTrain.py
+++ b/src/autoseg/train/STELARRTrain.py
@@ -33,17 +33,17 @@ def stelarr_train(
     Train STELARR model using Gunpowder library.
 
     Args:
-        raw_file (str): 
+        raw_file (str):
            Path to the raw data file.
-        out_file (str): 
+        out_file (str):
            Path to the output file for raw predictions.
-        voxel_size (int): 
+        voxel_size (int):
            Voxel size.
-        iterations (int): 
+        iterations (int):
            Number of training iterations.
-        warmup (int): 
+        warmup (int):
            Number of warm-up iterations.
-        save_every (int): 
+        save_every (int):
            Save predictions every 'save_every' iterations.
     """
     raw = gp.ArrayKey("RAW")
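
Note on call_python_train: the deleted Go runner passed --iterations, --raw_file, --voxel_size, and --warmup through to batch_run.py, while the Rust port currently launches the Python script with no arguments. Assuming the Python entry point still accepts those flags, the values already parsed into Config could be forwarded along these lines (a sketch, not part of the patch; the caller would then pass &config instead of &config.python_script_path):

    fn call_python_train(config: &Config) {
        // Forward the training parameters from config.json as CLI flags,
        // mirroring the flags the old Go runner passed to batch_run.py.
        let output = Command::new("python")
            .arg(&config.python_script_path)
            .arg("--iterations")
            .arg(config.iterations.to_string())
            .arg("--raw_file")
            .arg(&config.raw_file)
            .arg("--voxel_size")
            .arg(config.voxel_size.to_string())
            .arg("--warmup")
            .arg(config.warmup.to_string())
            .output()
            .expect("Failed to execute training");

        if output.status.success() {
            println!("Training executed successfully!");
        } else {
            println!("Error executing training:\n{}", String::from_utf8_lossy(&output.stderr));
        }
    }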
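
Note on the save_to_mongodb TODO: with the sync feature enabled in Cargo.toml, the insertion logic could be completed roughly as below. This is a minimal sketch assuming a local MongoDB instance; the "segmonitor" database and "training_runs" collection names are placeholders. Since Config derives Serialize, it converts directly to a BSON document:

    use mongodb::{bson, sync::Client};

    fn save_to_mongodb(config: &Config) {
        println!("Saving metrics to MongoDB...");
        let client = Client::with_uri_str("mongodb://localhost:27017")
            .expect("Failed to connect to MongoDB");
        // Store one document per training run, holding the run's full config.
        let doc = bson::to_document(config).expect("Failed to serialize config");
        client
            .database("segmonitor")
            .collection::<bson::Document>("training_runs")
            .insert_one(doc, None)
            .expect("Failed to insert training run");
    }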