diff --git a/tango/sweeps.py b/tango/sweeps.py
new file mode 100644
index 000000000..d1ffe5d36
--- /dev/null
+++ b/tango/sweeps.py
@@ -0,0 +1,84 @@
+import itertools
+import json
+import subprocess
+from collections import OrderedDict
+from dataclasses import dataclass
+
+from tango.common import Params, Registrable
+
+
+class Sweeper(Registrable):
+    """Run a tango experiment once per combination of swept hyperparameters.
+
+    ``sweep_config.config["config"]["hyperparameters"]`` must be a mapping from
+    override key to an iterable of candidate values for that key.
+    """
+
+    def __init__(
+        self,
+        main_config_path: str,
+        sweeps_config_path: str,
+        components: str,
+    ):
+        """Store the run settings and load the sweep config.
+
+        :param main_config_path: experiment config handed to ``tango run``.
+        :param sweeps_config_path: sweep config file, read via ``load_config``.
+        :param components: value passed to ``tango run --include-package``.
+        """
+        # Zero-argument super() keeps Registrable's own initializer in the chain;
+        # super(Registrable, self) would start the lookup *after* Registrable.
+        super().__init__()
+        self.main_config_path = main_config_path
+        self.sweep_config = load_config(sweeps_config_path)
+        self.components = components
+
+    def get_combinations(self) -> list:
+        """Return all hyperparameter combinations as a list of tuples.
+
+        Each tuple holds one value per hyperparameter, in the iteration order
+        of the hyperparameters mapping.
+        """
+        hyperparams = self.sweep_config.config["config"]["hyperparameters"]
+        return list(itertools.product(*hyperparams.values()))
+
+    def run_experiments(self):
+        """Launch one ``tango run`` subprocess per hyperparameter combination.
+
+        Runs execute one after another; the exit code of each run is not checked.
+        """
+        for combination in self.get_combinations():
+            overrides = self.override_hyperparameters(combination)
+            # TODO: need to figure where & how to store results / way to track runs
+            # specify what workspace to use
+            subprocess.call(
+                [
+                    "tango",
+                    "run",
+                    self.main_config_path,
+                    "--include-package",
+                    self.components,
+                    "--overrides",
+                    json.dumps(overrides),
+                ]
+            )
+
+    def override_hyperparameters(self, experiment_tuple: tuple) -> dict:
+        """Map each hyperparameter key to its value in ``experiment_tuple``.
+
+        :param experiment_tuple: one combination, ordered like the hyperparameter keys.
+        :return: overrides dict, keyed by hyperparameter name.
+        """
+        hyperparams = self.sweep_config.config["config"]["hyperparameters"]
+        return dict(zip(hyperparams.keys(), experiment_tuple))
+
+
+# function that loads the config from a specified yaml
or jsonnet file
+def load_config(sweeps_config_path: str) -> "SweepConfig":
+    return SweepConfig.from_file(sweeps_config_path)
+
+
+# data class that loads the parameters
+@dataclass(frozen=True)
+class SweepConfig(Params):
+    config: OrderedDict
diff --git a/test_fixtures/sweeps/basic_test/basic_arithmetic.py b/test_fixtures/sweeps/basic_test/basic_arithmetic.py
new file mode 100644
index 000000000..42aed8cf5
--- /dev/null
+++ b/test_fixtures/sweeps/basic_test/basic_arithmetic.py
@@ -0,0 +1,39 @@
+from typing import Union
+
+from tango import Step
+
+IntOrFloat = Union[int, float]
+
+
+@Step.register("addition")
+class AdditionStep(Step):
+    """Step that adds two numbers."""
+
+    # Accepts floats as well as ints: the fixture config feeds the result of
+    # "scale_down" (a true division) into this step through "add_x".
+    def run(self, num1: IntOrFloat, num2: IntOrFloat) -> IntOrFloat:
+        return num1 + num2
+
+
+@Step.register("scale_up")
+class ScaleUp(Step):
+    """Step that multiplies ``num1`` by ``factor``."""
+
+    def run(self, num1: int, factor: int) -> int:
+        return num1 * factor
+
+
+@Step.register("scale_down")
+class ScaleDown(Step):
+    """Step that divides ``num1`` by ``factor`` using true division."""
+
+    def run(self, num1: int, factor: int) -> IntOrFloat:
+        return num1 / factor
+
+
+@Step.register("print")
+class Print(Step):
+    """Step that prints ``num`` to standard output."""
+
+    def run(self, num: IntOrFloat) -> None:
+        print(num)
\ No newline at end of file
diff --git a/test_fixtures/sweeps/basic_test/config.jsonnet b/test_fixtures/sweeps/basic_test/config.jsonnet
new file mode 100644
index 000000000..dffd55793
--- /dev/null
+++ b/test_fixtures/sweeps/basic_test/config.jsonnet
@@ -0,0 +1,28 @@
+{
+    "steps": {
+        "add_numbers": {
+            "type": "addition",
+            "num1": 34,
+            "num2": 8
+        },
+        "multiply_result": {
+            "type": "scale_up",
+            "num1": {"type": "ref", "ref": "add_numbers"},
+            "factor": 10,
+        },
+        "divide_result": {
+            "type": "scale_down",
+            "num1": {"type": "ref", "ref": "multiply_result"},
+            "factor": 5,
+        },
+        "add_x": {
+            "type": "addition",
+            "num1": {"type": "ref", "ref": "divide_result"},
+            "num2": 1,
+        },
+        "print": {
+            "type": "print",
+            "num": {"type": "ref", "ref": "add_x"},
+        },
+    },
+}
diff --git a/test_fixtures/sweeps/basic_test/sweeps-config.jsonnet b/test_fixtures/sweeps/basic_test/sweeps-config.jsonnet
new file mode 100644
index
000000000..cdb432357
--- /dev/null
+++ b/test_fixtures/sweeps/basic_test/sweeps-config.jsonnet
@@ -0,0 +1,13 @@
+{
+    "config": {
+        "sweeper": "default",
+        "n_proc": 1,
+        "hyperparameters": {
+            "steps.add_numbers.num1": [8, 16],
+            "steps.add_numbers.num2": [2, 4],
+            "steps.multiply_result.factor": [1, 10],
+            "steps.divide_result.factor": [5, 10],
+            "steps.add_x.num2": [1, 2],
+        },
+    },
+}
\ No newline at end of file
diff --git a/test_fixtures/sweeps/basic_test/sweeps-config.yaml b/test_fixtures/sweeps/basic_test/sweeps-config.yaml
new file mode 100644
index 000000000..e69de29bb
diff --git a/test_fixtures/sweeps/cv_example_test/sweep-config.jsonnet b/test_fixtures/sweeps/cv_example_test/sweep-config.jsonnet
new file mode 100644
index 000000000..415988705
--- /dev/null
+++ b/test_fixtures/sweeps/cv_example_test/sweep-config.jsonnet
@@ -0,0 +1,8 @@
+{
+    sweeper: "default",
+    n_proc: 1,
+    hyperparameters: {
+        batch_size: [16, 32, 64],
+        lr: [0.01, 0.001],
+    },
+}
\ No newline at end of file
diff --git a/test_fixtures/sweeps/cv_example_test/sweep-config.yaml b/test_fixtures/sweeps/cv_example_test/sweep-config.yaml
new file mode 100644
index 000000000..c029c8755
--- /dev/null
+++ b/test_fixtures/sweeps/cv_example_test/sweep-config.yaml
@@ -0,0 +1,10 @@
+sweeper: "default"
+n_proc: 1
+hyperparameters:
+  batch_size:
+    - 16
+    - 32
+    - 64
+  lr:
+    - 0.001
+    - 0.0001
\ No newline at end of file