From eae129ddb0b61ddba68fe3329b4c0ed4732fe3f3 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 22:34:01 +0000 Subject: [PATCH 01/60] added __init__.py files to all folders --- benchmark/__init__.py | 0 benchmark/tpch/__init__.py | 0 dbms/__init__.py | 0 dbms/postgres/__init__.py | 0 misc/__init__.py | 0 tune/protox/agent/__init__.py | 0 tune/protox/agent/wolp/__init__.py | 0 tune/protox/embedding/__init__.py | 0 tune/protox/env/lsc/__init__.py | 0 tune/protox/env/mqo/__init__.py | 0 tune/protox/env/space/__init__.py | 0 tune/protox/env/target_reset/__init__.py | 0 tune/protox/env/util/__init__.py | 0 util/__init__.py | 0 14 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 benchmark/__init__.py create mode 100644 benchmark/tpch/__init__.py create mode 100644 dbms/__init__.py create mode 100644 dbms/postgres/__init__.py create mode 100644 misc/__init__.py create mode 100644 tune/protox/agent/__init__.py create mode 100644 tune/protox/agent/wolp/__init__.py create mode 100644 tune/protox/embedding/__init__.py create mode 100644 tune/protox/env/lsc/__init__.py create mode 100644 tune/protox/env/mqo/__init__.py create mode 100644 tune/protox/env/space/__init__.py create mode 100644 tune/protox/env/target_reset/__init__.py create mode 100644 tune/protox/env/util/__init__.py create mode 100644 util/__init__.py diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmark/tpch/__init__.py b/benchmark/tpch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbms/__init__.py b/dbms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbms/postgres/__init__.py b/dbms/postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/misc/__init__.py b/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/agent/__init__.py b/tune/protox/agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/agent/wolp/__init__.py b/tune/protox/agent/wolp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/embedding/__init__.py b/tune/protox/embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/lsc/__init__.py b/tune/protox/env/lsc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/mqo/__init__.py b/tune/protox/env/mqo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/space/__init__.py b/tune/protox/env/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/target_reset/__init__.py b/tune/protox/env/target_reset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/util/__init__.py b/tune/protox/env/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 00000000..e69de29b From 5eaf4bbc806a463eef3b41b1fe25bdb2e5436b92 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 22:40:04 +0000 Subject: [PATCH 02/60] fixed some simple type errors --- dbms/load_info_base_class.py | 2 +- tune/protox/embedding/vae.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbms/load_info_base_class.py b/dbms/load_info_base_class.py index a5aec24e..99b1032e 100644 --- a/dbms/load_info_base_class.py +++ b/dbms/load_info_base_class.py @@ -8,7 +8,7 @@ class LoadInfoBaseClass: def get_schema_fpath(self) 
-> str: raise NotImplemented - def get_tables_and_fpaths(self) -> list[(str, str)]: + def get_tables_and_fpaths(self) -> list[tuple[str, str]]: raise NotImplemented # If the subclassing benchmark does not have constraints, you can return None here diff --git a/tune/protox/embedding/vae.py b/tune/protox/embedding/vae.py index c1a657ac..aa5df154 100644 --- a/tune/protox/embedding/vae.py +++ b/tune/protox/embedding/vae.py @@ -308,7 +308,7 @@ def init(layer: nn.Module) -> None: else: init_fn(layer.weight) - modules = [encoder, decoder] + modules: list[nn.Module] = [encoder, decoder] for module in modules: if module is not None: module.apply(init) From 5b94622f04c5fada9611b80691fa35a8c8cb65d1 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 22:50:59 +0000 Subject: [PATCH 03/60] fixed some more errors --- misc/utils.py | 19 ++++++++----------- tune/protox/env/types.py | 6 ++++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/misc/utils.py b/misc/utils.py index 4a78c352..97356ad7 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -1,13 +1,11 @@ import os import shutil import subprocess -import sys from datetime import datetime from enum import Enum from pathlib import Path -from typing import Tuple +from typing import Tuple, Optional -import click import redis import yaml @@ -47,8 +45,7 @@ def get_runs_path_from_workspace_path(workspace_path): def get_scale_factor_string(scale_factor: float | str) -> str: - assert type(scale_factor) is float or type(scale_factor) is str - if scale_factor == SCALE_FACTOR_PLACEHOLDER: + if type(scale_factor) is str and scale_factor == SCALE_FACTOR_PLACEHOLDER: return scale_factor else: if float(int(scale_factor)) == scale_factor: @@ -62,9 +59,9 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: # Other parameters -BENCHMARK_NAME_PLACEHOLDER = "[benchmark_name]" -WORKLOAD_NAME_PLACEHOLDER = "[workload_name]" -SCALE_FACTOR_PLACEHOLDER = "[scale_factor]" +BENCHMARK_NAME_PLACEHOLDER: str = "[benchmark_name]" +WORKLOAD_NAME_PLACEHOLDER: str = "[workload_name]" +SCALE_FACTOR_PLACEHOLDER: str = "[scale_factor]" # Paths of config files in the codebase. These are always relative paths. # The reason these can be relative paths instead of functions taking in codebase_path as input is because relative paths are relative to the codebase root @@ -481,7 +478,7 @@ def extract_from_task_run_fordpath( # TODO(phw2): really look at the clean PR to see what it changed # TODO(phw2): after merging agent-train, refactor some code in agent-train to use save_file() instead of open_and_save() -def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> Path: +def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> None: """ If an external function takes in a file/directory as input, you will not be able to call open_and_save(). In these situations, just call save_file(). @@ -544,7 +541,7 @@ def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> Path: # TODO(phw2): refactor our manual symlinking in postgres/cli.py to use link_result() instead def link_result( - dbgym_cfg: DBGymConfig, result_fordpath: Path, custom_result_name: str | None = None + dbgym_cfg: DBGymConfig, result_fordpath: Path, custom_result_name: Optional[str] = None ) -> Path: """ result_fordpath must be a "result", meaning it was generated inside dbgym_cfg.dbgym_this_run_path. 
@@ -564,7 +561,7 @@ def link_result( assert is_child_path(result_fordpath, dbgym_cfg.dbgym_this_run_path) assert not os.path.islink(result_fordpath) - if custom_result_name != None: + if type(custom_result_name) is str: result_name = custom_result_name else: if os.path.isfile(result_fordpath): diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 976317ed..846a4762 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -56,10 +56,12 @@ class ServerIndexMetadata(TypedDict, total=False): ServerTableIndexMetadata = NewType( "ServerTableIndexMetadata", dict[str, dict[str, ServerIndexMetadata]] ) -ProtoAction = NewType("ProtoAction", torch.Tensor) +class ProtoAction(torch.Tensor): + pass KnobMap = NewType("KnobMap", dict[str, Union[Knob, CategoricalKnob]]) -KnobSpaceRawAction = NewType("KnobSpaceRawAction", torch.Tensor) +class KnobSpaceRawAction(torch.Tensor): + pass # {knob.name(): knob_value, ...} KnobSpaceAction = NewType("KnobSpaceAction", dict[str, Any]) # {knob.name(): knob_value, ...} From 3500ba8ca93f4f700e74e157ebbd0b3262afdb5f Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 22:56:37 +0000 Subject: [PATCH 04/60] fixed tune/protox/env/logger.py --- tune/protox/env/logger.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tune/protox/env/logger.py b/tune/protox/env/logger.py index 627e6a3c..07459a18 100644 --- a/tune/protox/env/logger.py +++ b/tune/protox/env/logger.py @@ -9,7 +9,7 @@ import numpy as np from plumbum import local -from torch.utils.tensorboard import SummaryWriter # type: ignore +from torch.utils.tensorboard import SummaryWriter from typing_extensions import ParamSpec from misc.utils import DBGymConfig @@ -25,16 +25,17 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T: ret = f(*args, **kwargs) # TODO(wz2): This is a hack to get a logger instance. - assert hasattr(args[0], "logger"), print(args[0], type(args[0])) + first_arg = args[0] # type: ignore[index] # Ignore the indexing type error + assert hasattr(first_arg, "logger"), print(first_arg, type(first_arg)) - if args[0].logger is None: + if first_arg.logger is None: # If there is no logger, just return. 
return ret - assert isinstance(args[0].logger, Logger) - if args[0].logger is not None: - cls_name = type(args[0]).__name__ - args[0].logger.record(f"{cls_name}_{key}", time.time() - start) + assert isinstance(first_arg.logger, Logger) + if first_arg.logger is not None: + cls_name = type(first_arg).__name__ + first_arg.logger.record(f"{cls_name}_{key}", time.time() - start) return ret return wrapped_f @@ -81,7 +82,7 @@ def __init__( self.writer: Union[SummaryWriter, None] = None if self.trace: self.tensorboard_dpath.mkdir(parents=True, exist_ok=True) - self.writer = SummaryWriter(self.tensorboard_dpath) # type: ignore + self.writer = SummaryWriter(self.tensorboard_dpath) self.iteration = 1 self.iteration_data: dict[str, Any] = {} @@ -144,14 +145,14 @@ def advance(self) -> None: for key, value in self.iteration_data.items(): if isinstance(value, str): # str is considered a np.ScalarType - self.writer.add_text(key, value, self.iteration) # type: ignore + self.writer.add_text(key, value, self.iteration) else: - self.writer.add_scalar(key, value, self.iteration) # type: ignore + self.writer.add_scalar(key, value, self.iteration) del self.iteration_data self.iteration_data = {} self.iteration += 1 - self.writer.flush() # type: ignore + self.writer.flush() def record(self, key: str, value: Any) -> None: stack = inspect.stack(context=2) @@ -168,4 +169,4 @@ def flush(self) -> None: if self.trace: assert self.writer self.advance() - self.writer.flush() # type: ignore + self.writer.flush() From 4bf44adfc565a410bba05d6a8038736dac8a486a Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 22:59:43 +0000 Subject: [PATCH 05/60] fixed task.py --- scripts/mypy.ini | 3 +++ task.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 scripts/mypy.ini diff --git a/scripts/mypy.ini b/scripts/mypy.ini new file mode 100644 index 00000000..e68d5fa4 --- /dev/null +++ b/scripts/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +strict = True +ignore_missing_imports = True \ No newline at end of file diff --git a/task.py b/task.py index 7871fdc4..37ac3a69 100644 --- a/task.py +++ b/task.py @@ -16,7 +16,7 @@ @click.group() @click.pass_context -def task(ctx): +def task(ctx: click.Context) -> None: """💩💩💩 CMU-DB Database Gym: github.com/cmu-db/dbgym 💩💩💩""" dbgym_config_path = Path(os.getenv("DBGYM_CONFIG_PATH", "dbgym_config.yaml")) ctx.obj = DBGymConfig(dbgym_config_path) From 36eb67a2a2b08a51842203abe3b37df17e35f215 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 23:00:29 +0000 Subject: [PATCH 06/60] fixed tune/cli.py --- tune/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/cli.py b/tune/cli.py index 7d1f98f1..5b5dec02 100644 --- a/tune/cli.py +++ b/tune/cli.py @@ -6,7 +6,7 @@ @click.group(name="tune") @click.pass_obj -def tune_group(dbgym_cfg: DBGymConfig): +def tune_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("tune") From 2c32b4f22ac57c66b0d1ee432d23152eae342091 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 23:00:57 +0000 Subject: [PATCH 07/60] fixed tune/protox/cli.py --- tune/protox/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/cli.py b/tune/protox/cli.py index 15160f89..4c09f383 100644 --- a/tune/protox/cli.py +++ b/tune/protox/cli.py @@ -7,7 +7,7 @@ @click.group(name="protox") @click.pass_obj -def protox_group(dbgym_cfg: DBGymConfig): +def protox_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("protox") From 06e1e40e5a1adb257d9d039851643ca5dc843fd6 Mon Sep 
17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 23:01:18 +0000 Subject: [PATCH 08/60] fixed tune/protox/agent/cli.py --- tune/protox/agent/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/agent/cli.py b/tune/protox/agent/cli.py index 98f7bb22..fcc85ee1 100644 --- a/tune/protox/agent/cli.py +++ b/tune/protox/agent/cli.py @@ -8,7 +8,7 @@ @click.group("agent") @click.pass_obj -def agent_group(dbgym_cfg: DBGymConfig): +def agent_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("agent") From 3081e75f51e3aba12570a1e118634a7318e37bae Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 23:04:35 +0000 Subject: [PATCH 09/60] fixed tune/protox/agent/tune.py --- misc/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/misc/utils.py b/misc/utils.py index 97356ad7..fd4749ae 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -4,7 +4,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Tuple, Optional +from typing import Callable, Tuple, Optional import redis import yaml @@ -78,7 +78,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: ) # Generally useful functions -workload_name_fn = ( +workload_name_fn: Callable[[float | str, int, int, str], str] = ( lambda scale_factor, seed_start, seed_end, query_subset: f"workload_sf{get_scale_factor_string(scale_factor)}_{seed_start}_{seed_end}_{query_subset}" ) @@ -87,13 +87,13 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: traindata_fname = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedding_traindata.parquet" ) -default_embedder_dname = ( +default_embedder_dname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedder" ) default_hpoed_agent_params_fname = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_hpoed_agent_params.json" ) -default_tuning_steps_dname = ( +default_tuning_steps_dname: Callable[[str, str, bool], str] = ( lambda benchmark_name, workload_name, boot_enabled_during_tune: f"{benchmark_name}_{workload_name}{'_boot' if boot_enabled_during_tune else ''}_tuning_steps" ) @@ -126,7 +126,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (default_embedder_dname(benchmark_name, workload_name) + ".link") ) -default_hpoed_agent_params_path = ( +default_hpoed_agent_params_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) From 5b5d50bce803113c546198c752092ab84fe703ae Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sun, 1 Sep 2024 23:36:20 +0000 Subject: [PATCH 10/60] fixed tune/protox/agent/hpo.py --- misc/utils.py | 20 ++--- tune/protox/agent/hpo.py | 190 +++++++++++++++++++++------------------ tune/protox/env/types.py | 6 +- 3 files changed, 113 insertions(+), 103 deletions(-) diff --git a/misc/utils.py b/misc/utils.py index fd4749ae..c134b112 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -54,7 +54,7 @@ def get_scale_factor_string(scale_factor: float | str) -> str: return str(scale_factor).replace(".", "point") -def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: +def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float | str) -> str: return f"{benchmark_name}_sf{get_scale_factor_string(scale_factor)}_pristine_dbdata.tgz" @@ -68,11 +68,11 @@ def 
get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: DEFAULT_HPO_SPACE_PATH = PROTOX_EMBEDDING_PATH / "default_hpo_space.json" DEFAULT_SYSKNOBS_PATH = PROTOX_AGENT_PATH / "default_sysknobs.yaml" DEFAULT_BOOT_CONFIG_FPATH = POSTGRES_PATH / "default_boot_config.yaml" -default_benchmark_config_path = ( +default_benchmark_config_path: Callable[[str], Path] = ( lambda benchmark_name: PROTOX_PATH / f"default_{benchmark_name}_benchmark_config.yaml" ) -default_benchbase_config_path = ( +default_benchbase_config_path: Callable[[str], Path] = ( lambda benchmark_name: PROTOX_PATH / f"default_{benchmark_name}_benchbase_config.xml" ) @@ -90,7 +90,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: default_embedder_dname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedder" ) -default_hpoed_agent_params_fname = ( +default_hpoed_agent_params_fname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_hpoed_agent_params.json" ) default_tuning_steps_dname: Callable[[str, str, bool], str] = ( @@ -118,7 +118,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (traindata_fname(benchmark_name, workload_name) + ".link") ) -default_embedder_path = ( +default_embedder_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -134,7 +134,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (default_hpoed_agent_params_fname(benchmark_name, workload_name) + ".link") ) -default_workload_path = ( +default_workload_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -142,7 +142,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (workload_name + ".link") ) -default_pristine_dbdata_snapshot_path = ( +default_pristine_dbdata_snapshot_path: Callable[[Path, str, float | str], Path] = ( lambda workspace_path, benchmark_name, scale_factor: get_symlinks_path_from_workspace_path( workspace_path ) @@ -150,10 +150,10 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (get_dbdata_tgz_name(benchmark_name, scale_factor) + ".link") ) -default_dbdata_parent_dpath = lambda workspace_path: get_tmp_path_from_workspace_path( - workspace_path +default_dbdata_parent_dpath: Callable[[Path], Path] = ( + lambda workspace_path: get_tmp_path_from_workspace_path(workspace_path) ) -default_pgbin_path = ( +default_pgbin_path: Callable[[Path], Path] = ( lambda workspace_path: get_symlinks_path_from_workspace_path(workspace_path) / "dbgym_dbms_postgres" / "build" diff --git a/tune/protox/agent/hpo.py b/tune/protox/agent/hpo.py index 05ca46ef..8fdf860c 100644 --- a/tune/protox/agent/hpo.py +++ b/tune/protox/agent/hpo.py @@ -6,7 +6,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Type, Union import click import numpy as np @@ -54,26 +54,26 @@ class AgentHPOArgs: def __init__( self, - benchmark_name, - workload_name, - embedder_path, - benchmark_config_path, - benchbase_config_path, - sysknobs_path, - pristine_dbdata_snapshot_path, - dbdata_parent_dpath, - pgbin_path, - workload_path, - seed, - agent, - max_concurrent, - num_samples, - 
tune_duration_during_hpo, - workload_timeout, - query_timeout, - enable_boot_during_hpo, - boot_config_fpath_during_hpo, - build_space_good_for_boot, + benchmark_name: str, + workload_name: str, + embedder_path: Path, + benchmark_config_path: Path, + benchbase_config_path: Path, + sysknobs_path: Path, + pristine_dbdata_snapshot_path: Path, + dbdata_parent_dpath: Path, + pgbin_path: Path, + workload_path: Path, + seed: int, + agent: str, + max_concurrent: int, + num_samples: int, + tune_duration_during_hpo: float, + workload_timeout: float, + query_timeout: float, + enable_boot_during_hpo: bool, + boot_config_fpath_during_hpo: Path, + build_space_good_for_boot: bool, ): self.benchmark_name = benchmark_name self.workload_name = workload_name @@ -119,35 +119,38 @@ def __init__( ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @click.option( "--embedder-path", + type=Path, default=None, help=f"The path to the directory that contains an `embedder.pth` file with a trained encoder and decoder as well as a `config` file. The default is {default_embedder_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--benchbase-config-path", - default=None, type=Path, + default=None, help=f"The path to the .xml config file for BenchBase, used to run OLTP workloads. The default is {default_benchbase_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--sysknobs-path", + type=Path, default=DEFAULT_SYSKNOBS_PATH, help=f"The path to the file configuring the space of system knobs the tuner can tune.", ) @click.option( "--pristine-dbdata-snapshot-path", - default=None, type=Path, + default=None, help=f"The path to the .tgz snapshot of the dbdata directory to use as a starting point for tuning. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( @@ -158,57 +161,62 @@ def __init__( ) @click.option( "--dbdata-parent-dpath", - default=None, type=Path, + default=None, help=f"The path to the parent directory of the dbdata which will be actively tuned. The default is {default_dbdata_parent_dpath(WORKSPACE_PATH_PLACEHOLDER)}.", ) @click.option( "--pgbin-path", - default=None, type=Path, + default=None, help=f"The path to the bin containing Postgres executables. The default is {default_pgbin_path(WORKSPACE_PATH_PLACEHOLDER)}.", ) @click.option( "--workload-path", - default=None, type=Path, + default=None, help=f"The path to the directory that specifies the workload (such as its queries and order of execution). The default is {default_workload_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). The default is a random value.", ) @click.option( - "--agent", default="wolp", help=f"The RL algorithm to use for the tuning agent." + "--agent", + type=str, + default="wolp", + help=f"The RL algorithm to use for the tuning agent." ) @click.option( "--max-concurrent", + type=int, default=1, help=f"The max # of concurrent agent models to train. 
Note that unlike in HPO, all will use the same hyperparameters. This just helps control for other sources of randomness.", ) @click.option( "--num-samples", + type=int, default=40, help=f"The # of times to specific hyperparameter configs to sample from the hyperparameter search space and train agent models with.", ) @click.option( "--tune-duration-during-hpo", - default=4, type=float, + default=4.0, help="The number of hours to run each hyperparamer config tuning trial for.", ) @click.option( "--workload-timeout", - default=DEFAULT_WORKLOAD_TIMEOUT, type=int, + default=DEFAULT_WORKLOAD_TIMEOUT, help="The timeout (in seconds) of a workload. We run the workload once per DBMS configuration. For OLAP workloads, certain configurations may be extremely suboptimal, so we need to time out the workload.", ) @click.option( "--query-timeout", - default=30, type=int, + default=30, help="The timeout (in seconds) of a query. See the help of --workload-timeout for the motivation of this.", ) @click.option( @@ -218,8 +226,8 @@ def __init__( ) @click.option( "--boot-config-fpath-during-hpo", - default=DEFAULT_BOOT_CONFIG_FPATH, type=Path, + default=DEFAULT_BOOT_CONFIG_FPATH, help="The path to the file configuring Boot when running HPO. When tuning, you may use a different Boot config.", ) # Building a space good for Boot is subtly different from whether we enable Boot during HPO. @@ -240,58 +248,58 @@ def __init__( help="Whether to avoid certain options that are known to not perform well when Boot is enabled. See the codebase for why this is subtly different from --enable-boot-during-hpo.", ) def hpo( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - embedder_path, - benchmark_config_path, - benchbase_config_path, - sysknobs_path, - pristine_dbdata_snapshot_path, - intended_dbdata_hardware, - dbdata_parent_dpath, - pgbin_path, - workload_path, - seed, - agent, - max_concurrent, - num_samples, - tune_duration_during_hpo, - workload_timeout, - query_timeout, + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + embedder_path: Optional[Path], + benchmark_config_path: Optional[Path], + benchbase_config_path: Optional[Path], + sysknobs_path: Path, + pristine_dbdata_snapshot_path: Optional[Path], + intended_dbdata_hardware: str, + dbdata_parent_dpath: Optional[Path], + pgbin_path: Optional[Path], + workload_path: Optional[Path], + seed: Optional[int], + agent: str, + max_concurrent: int, + num_samples: int, + tune_duration_during_hpo: float, + workload_timeout: int, + query_timeout: int, enable_boot_during_hpo: bool, boot_config_fpath_during_hpo: Path, build_space_good_for_boot: bool, -): +) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if embedder_path == None: + if embedder_path is None: embedder_path = default_embedder_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if benchbase_config_path == None: + if benchbase_config_path is None: benchbase_config_path = default_benchbase_config_path(benchmark_name) - if pristine_dbdata_snapshot_path == None: + if pristine_dbdata_snapshot_path is None: pristine_dbdata_snapshot_path = default_pristine_dbdata_snapshot_path( dbgym_cfg.dbgym_workspace_path, 
benchmark_name, scale_factor ) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if workload_path == None: + if workload_path is None: workload_path = default_workload_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if seed == None: - seed = random.randint(0, 1e8) + if seed is None: + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths embedder_path = conv_inputpath_to_realabspath(dbgym_cfg, embedder_path) @@ -358,15 +366,15 @@ def build_space( benchmark_config: dict[str, Any], workload_path: Path, embedder_path: list[Path], - pgconn_info: dict[str, str], + pgconn_info: dict[str, Path], benchbase_config: dict[str, Any] = {}, - tune_duration_during_hpo: int = 30, + tune_duration_during_hpo: float = 30.0, seed: int = 0, enable_boot_during_hpo: bool = False, - boot_config_fpath_during_hpo: Path = None, + boot_config_fpath_during_hpo: Path = Path(), build_space_good_for_boot: bool = False, - workload_timeouts: list[int] = [600], - query_timeouts: list[int] = [30], + workload_timeouts: list[float] = [600.0], + query_timeouts: list[float] = [30.0], ) -> dict[str, Any]: return { @@ -551,7 +559,7 @@ def setup(self, hpo_params: dict[str, Any]) -> None: # Attach mythril directory to the search path. sys.path.append(os.path.expanduser(self.dbgym_cfg.dbgym_repo_path)) - torch.set_default_dtype(torch.float32) # type: ignore + torch.set_default_dtype(torch.float32) seed = ( hpo_params["seed"] if hpo_params["seed"] != -1 @@ -640,7 +648,7 @@ def step(self) -> dict[Any, Any]: def cleanup(self) -> None: self.logger.flush() - self.env.close() # type: ignore + self.env.close() if Path(self.signal).exists(): os.remove(self.signal) @@ -650,11 +658,12 @@ def cleanup(self) -> None: # Using a function to create a class is Ray's recommended way of doing this (see # https://discuss.ray.io/t/using-static-variables-to-control-trainable-subclass-in-ray-tune/808/4) # If you don't create the class with a function, it doesn't work due to how Ray serializes classes -def create_tune_opt_class(dbgym_cfg_param): +global_dbgym_cfg: DBGymConfig +def create_tune_opt_class(dbgym_cfg_param: DBGymConfig) -> Type[Trainable]: global global_dbgym_cfg global_dbgym_cfg = dbgym_cfg_param - class TuneOpt(Trainable): + class TuneOpt(Trainable): # type: ignore dbgym_cfg = global_dbgym_cfg def setup(self, hpo_params: dict[str, Any]) -> None: @@ -697,20 +706,23 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: workload_timeouts = [hpo_args.workload_timeout] query_timeouts = [hpo_args.query_timeout] - benchbase_config = ( - { - "oltp_config": { - "oltp_num_terminals": hpo_args.oltp_num_terminals, - "oltp_duration": hpo_args.oltp_duration, - "oltp_sf": hpo_args.oltp_sf, - "oltp_warmup": hpo_args.oltp_warmup, - }, - "benchbase_path": hpo_args.benchbase_path, - "benchbase_config_path": hpo_args.benchbase_config_path, - } - if is_oltp - else {} - ) + assert not is_oltp + benchbase_config: dict[str, Any] = {} + # This is commented out because OLTP is currently not implemented. 
+ # benchbase_config = ( + # { + # "oltp_config": { + # "oltp_num_terminals": hpo_args.oltp_num_terminals, + # "oltp_duration": hpo_args.oltp_duration, + # "oltp_sf": hpo_args.oltp_sf, + # "oltp_warmup": hpo_args.oltp_warmup, + # }, + # "benchbase_path": hpo_args.benchbase_path, + # "benchbase_config_path": hpo_args.benchbase_config_path, + # } + # if is_oltp + # else {} + # ) space = build_space( sysknobs, @@ -738,7 +750,7 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: ) # Scheduler. - scheduler = FIFOScheduler() # type: ignore + scheduler = FIFOScheduler() # Search. search = BasicVariantGenerator(max_concurrent=hpo_args.max_concurrent) diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 846a4762..4fe5f2b9 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -56,12 +56,10 @@ class ServerIndexMetadata(TypedDict, total=False): ServerTableIndexMetadata = NewType( "ServerTableIndexMetadata", dict[str, dict[str, ServerIndexMetadata]] ) -class ProtoAction(torch.Tensor): - pass +ProtoAction = NewType("ProtoAction", torch.Tensor) # type: ignore KnobMap = NewType("KnobMap", dict[str, Union[Knob, CategoricalKnob]]) -class KnobSpaceRawAction(torch.Tensor): - pass +KnobSpaceRawAction = NewType("KnobSpaceRawAction", torch.Tensor) # type: ignore # {knob.name(): knob_value, ...} KnobSpaceAction = NewType("KnobSpaceAction", dict[str, Any]) # {knob.name(): knob_value, ...} From 6a0468a5deb7bfec1f127ab0ef620f86a664f558 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 00:00:37 +0000 Subject: [PATCH 11/60] fixed tune/protox/agent/replay.py --- misc/utils.py | 2 +- tune/protox/agent/replay.py | 65 ++++++++++++++++----------------- tune/protox/env/types.py | 11 ++++-- tune/protox/env/util/pg_conn.py | 2 +- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/misc/utils.py b/misc/utils.py index c134b112..229bc075 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -163,7 +163,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float | str) -> str: / "postgres" / "bin" ) -default_tuning_steps_dpath = ( +default_tuning_steps_dpath: Callable[[Path, str, str, bool], Path] = ( lambda workspace_path, benchmark_name, workload_name, boot_enabled_during_tune: get_symlinks_path_from_workspace_path( workspace_path ) diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index 6c59ba5c..cca69911 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -6,10 +6,12 @@ replayed tuning run is not. """ +from datetime import datetime import json import logging import pickle from pathlib import Path +from typing import Any, Optional, Set import click import pandas as pd @@ -28,8 +30,9 @@ from tune.protox.agent.build_trial import build_trial from tune.protox.env.pg_env import PostgresEnv from tune.protox.env.space.holon_space import HolonSpace +from tune.protox.env.space.primitive.index import IndexAction from tune.protox.env.space.utils import fetch_server_indexes, fetch_server_knobs -from tune.protox.env.types import HolonAction +from tune.protox.env.types import ActionsInfo, HolonAction from tune.protox.env.workload import Workload REPLAY_DATA_FNAME = "replay_data.csv" @@ -38,11 +41,13 @@ class ReplayArgs: def __init__( self, - workload_timeout_during_replay: bool, + # If it's None, it'll get set later on inside replay_tuning_run(). 
+ workload_timeout_during_replay: Optional[float], replay_all_variations: bool, simulated: bool, - cutoff: float, - blocklist: list, + # If it's None, it'll get set later on inside replay_tuning_run(). + cutoff: Optional[float], + blocklist: list[str], ): self.workload_timeout_during_replay = workload_timeout_during_replay self.replay_all_variations = replay_all_variations @@ -73,6 +78,7 @@ def __init__( ) @click.option( "--scale-factor", + type=float, default=1.0, help="The scale factor used when generating the data of the benchmark.", ) @@ -83,14 +89,14 @@ def __init__( ) @click.option( "--tuning-steps-dpath", - default=None, type=Path, + default=None, help="The path to the `tuning_steps` directory to be replayed.", ) @click.option( "--workload-timeout-during-replay", + type=float, default=None, - type=int, # You can make it use the workload timeout used during tuning if you want. # I just made it use the workload timeout from HPO because I don't currently persist the tuning HPO params. help="The timeout (in seconds) of a workload when replaying. By default, it will be equal to the workload timeout used during HPO.", @@ -107,14 +113,14 @@ def __init__( ) @click.option( "--cutoff", - default=None, type=float, + default=None, help='Only evaluate configs up to cutoff hours. None means "evaluate all configs".', ) @click.option( "--blocklist", + type=list[str], default=[], - type=list, help="Ignore running queries in the blocklist.", ) def replay( @@ -125,17 +131,17 @@ def replay( query_subset: str, scale_factor: float, boot_enabled_during_tune: bool, - tuning_steps_dpath: Path, - workload_timeout_during_replay: bool, + tuning_steps_dpath: Optional[Path], + workload_timeout_during_replay: Optional[float], replay_all_variations: bool, simulated: bool, - cutoff: float, - blocklist: list, + cutoff: Optional[float], + blocklist: list[str], ) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if tuning_steps_dpath == None: + if tuning_steps_dpath is None: tuning_steps_dpath = default_tuning_steps_dpath( dbgym_cfg.dbgym_workspace_path, benchmark_name, @@ -161,7 +167,7 @@ def replay( def replay_tuning_run( dbgym_cfg: DBGymConfig, tuning_steps_dpath: Path, replay_args: ReplayArgs -): +) -> None: """ Replay a single tuning run (as in one tuning_steps/ folder). """ @@ -174,7 +180,7 @@ def _is_tuning_step_line(line: str) -> bool: hpo_params = json.load(f) # Set defaults that depend on hpo_params - if replay_args.workload_timeout_during_replay == None: + if replay_args.workload_timeout_during_replay is None: replay_args.workload_timeout_during_replay = hpo_params["workload_timeout"][ str(TuningMode.HPO) ] @@ -190,6 +196,7 @@ def _is_tuning_step_line(line: str) -> bool: # This finds all the [time] folders in tuning_steps/ (except "baseline" since we ignore that in `_is_tuning_step_line()`), # so you could just do `ls tuning_steps/` if you wanted to. 
folders = [] + start_time: Optional[datetime] = None start_found = False output_log_fpath = tuning_steps_dpath / "output.log" with open_and_save(dbgym_cfg, output_log_fpath, "r") as f: @@ -209,8 +216,9 @@ def _is_tuning_step_line(line: str) -> bool: time_since_start = parse( line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] ) + assert type(start_time) is datetime if ( - replay_args.cutoff == None + replay_args.cutoff is None or (time_since_start - start_time).total_seconds() < replay_args.cutoff * 3600 ): @@ -241,7 +249,7 @@ def _is_tuning_step_line(line: str) -> bool: num_lines += 1 # A convenience wrapper around execute_workload() which fills in the arguments properly and processes the return values. - def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: + def _execute_workload_wrapper(actions_info: ActionsInfo) -> tuple[int, int, bool, float]: logging.info( f"\n\nfetch_server_knobs(): {fetch_server_knobs(pg_env.pg_conn.conn(), action_space.get_knob_space().tables, action_space.get_knob_space().knobs, pg_env.workload.queries)}\n\n" ) @@ -299,8 +307,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: current_step = 0 start_found = False start_time = None - maximal_repo = None - existing_index_acts = [] + existing_index_acts: Set[IndexAction] = set() for line in f: # Keep going until we've found the start. @@ -316,19 +323,10 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: continue elif _is_tuning_step_line(line): - if _is_tuning_step_line(line): - repo = eval(line.split("Running ")[-1])[-1] - time_since_start = parse( - line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] - ) - elif "Found new maximal state with" in line: - repo = eval(maximal_repo.split("Running ")[-1])[-1] - time_since_start = parse( - maximal_repo.split("DEBUG:")[-1] - .split(" Running")[0] - .split("[")[0] - ) - maximal_repo = None + repo = eval(line.split("Running ")[-1])[-1] + time_since_start = parse( + line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] + ) # Get the original runtime as well as whether any individual queries and/or the full workload timed out. run_raw_csv_fpath = tuning_steps_dpath / repo / "run.raw.csv" @@ -367,7 +365,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: with open_and_save( dbgym_cfg, tuning_steps_dpath / repo / "action.pkl", "rb" ) as f: - actions_info = pickle.load(f) + actions_info: ActionsInfo = pickle.load(f) all_holon_action_variations = actions_info[ "all_holon_action_variations" ] @@ -451,6 +449,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: ) # Perform some validity checks and then add this tuning step's data to `run_data``. + assert isinstance(start_time, datetime) this_step_run_data = { "step": current_step, "time_since_start": (time_since_start - start_time).total_seconds(), diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 4fe5f2b9..17c963a1 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -166,6 +166,11 @@ class QuerySpec(TypedDict, total=False): tbl_fold_iterations: int +class ActionsInfo(TypedDict): + all_holon_action_variations: list[Tuple[str, HolonAction]] + best_observed_holon_action: HolonAction + + class EnvInfoDict(TypedDict, total=False): # Original baseline metric. baseline_metric: float @@ -193,12 +198,12 @@ class EnvInfoDict(TypedDict, total=False): # Query metric data. 
query_metric_data: Optional[dict[str, BestQueryRun]] # Information about the actions that were executed this step. - # The actions are in a format usable by replay. (TODO(phw2)) - actions_info: Tuple["KnobSpaceAction", "IndexAction", "QuerySpaceAction"] + # The actions are in a format usable by replay. + actions_info: Optional[ActionsInfo] # ProtoAction of the altered step action. maximal_embed: ProtoAction # New state container. state_container: HolonStateContainer # What the LSC associated with the action is. - lsc: float + lsc: float \ No newline at end of file diff --git a/tune/protox/env/util/pg_conn.py b/tune/protox/env/util/pg_conn.py index 233b49bc..94da1ad3 100644 --- a/tune/protox/env/util/pg_conn.py +++ b/tune/protox/env/util/pg_conn.py @@ -358,7 +358,7 @@ def cancel_fn(conn_str: str) -> None: self.disconnect() return 0, None - def restore_pristine_snapshot(self): + def restore_pristine_snapshot(self) -> None: self._restore_snapshot(self.pristine_dbdata_snapshot_fpath) def restore_checkpointed_snapshot(self): From 9318b8cfc2b6cb7ca59dd614b0a628347052abc1 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:32:06 +0000 Subject: [PATCH 12/60] added mypy to req.txt --- dependencies/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dependencies/requirements.txt b/dependencies/requirements.txt index ba32594c..1be1b2b0 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -1,6 +1,7 @@ absl-py==2.1.0 aiosignal==1.3.1 astunparse==1.6.3 +async-timeout==4.0.3 attrs==23.2.0 black==24.2.0 cachetools==5.3.2 @@ -44,6 +45,7 @@ MarkupSafe==2.1.4 ml-dtypes==0.2.0 mpmath==1.3.0 msgpack==1.0.7 +mypy==1.11.2 mypy-extensions==1.0.0 networkx==3.2.1 numpy==1.26.3 @@ -92,6 +94,7 @@ pytz==2023.4 PyYAML==6.0.1 ray==2.9.3 record-keeper==0.9.32 +redis==5.0.3 referencing==0.33.0 requests==2.31.0 requests-oauthlib==1.3.1 @@ -122,4 +125,3 @@ virtualenv==20.25.0 Werkzeug==3.0.1 wrapt==1.14.1 zipp==3.17.0 -redis==5.0.3 From 955216ba32cbb2094ae766789c94931384861479 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:40:15 +0000 Subject: [PATCH 13/60] fixed a few scattered errors --- .gitignore | 1 + scripts/mypy.ini | 1 - tune/protox/embedding/vae.py | 4 ++-- tune/protox/env/types.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 383aa46f..4d6abb6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ +.mypy_cache/ .conda/ .idea/ test_clean_scratchspace/ diff --git a/scripts/mypy.ini b/scripts/mypy.ini index e68d5fa4..b1efbe84 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -1,3 +1,2 @@ [mypy] strict = True -ignore_missing_imports = True \ No newline at end of file diff --git a/tune/protox/embedding/vae.py b/tune/protox/embedding/vae.py index aa5df154..eb03a3f0 100644 --- a/tune/protox/embedding/vae.py +++ b/tune/protox/embedding/vae.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from pytorch_metric_learning import losses, reducers # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import losses, reducers +from pytorch_metric_learning.utils import common_functions as c_f def gen_vae_collate( diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 17c963a1..6b7758f7 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -56,10 +56,10 @@ class ServerIndexMetadata(TypedDict, total=False): 
ServerTableIndexMetadata = NewType( "ServerTableIndexMetadata", dict[str, dict[str, ServerIndexMetadata]] ) -ProtoAction = NewType("ProtoAction", torch.Tensor) # type: ignore +ProtoAction = NewType("ProtoAction", torch.Tensor) KnobMap = NewType("KnobMap", dict[str, Union[Knob, CategoricalKnob]]) -KnobSpaceRawAction = NewType("KnobSpaceRawAction", torch.Tensor) # type: ignore +KnobSpaceRawAction = NewType("KnobSpaceRawAction", torch.Tensor) # {knob.name(): knob_value, ...} KnobSpaceAction = NewType("KnobSpaceAction", dict[str, Any]) # {knob.name(): knob_value, ...} From e054e828a619ee8b460f0cb7f1e2664158aaf1ea Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:47:27 +0000 Subject: [PATCH 14/60] fixed tune/protox/embedding/train_args.py --- tune/protox/embedding/train.py | 45 +++++++++++++++++------ tune/protox/embedding/train_args.py | 56 +++++++++++++++++------------ 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/tune/protox/embedding/train.py b/tune/protox/embedding/train.py index 69eba251..7aed79d4 100644 --- a/tune/protox/embedding/train.py +++ b/tune/protox/embedding/train.py @@ -1,6 +1,7 @@ import logging import random from pathlib import Path +from typing import Optional import click import numpy as np @@ -59,75 +60,90 @@ ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--traindata-path", - default=None, type=Path, + default=None, help=f"The path to the .parquet file containing the training data to use to train the embedding models. The default is {default_traindata_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). The default is a random value.", ) # train args @click.option( "--hpo-space-path", + type=Path, default=DEFAULT_HPO_SPACE_PATH, - type=str, help="The path to the .json file defining the search space for hyperparameter optimization (HPO).", ) @click.option( "--train-max-concurrent", - default=1, type=int, + default=1, help="The max # of concurrent embedding models to train during hyperparameter optimization. This is usually set lower than `nproc` to reduce memory pressure.", ) @click.option("--iterations-per-epoch", default=1000, help=f"TODO(wz2)") @click.option( "--num-samples", + type=int, default=40, help=f"The # of times to specific hyperparameter configs to sample from the hyperparameter search space and train embedding models with.", ) -@click.option("--train-size", default=0.99, help=f"TODO(wz2)") +@click.option( + "--train-size", + type=float, + default=0.99, + help=f"TODO(wz2)" +) # analyze args @click.option( - "--start-epoch", default=0, help="The epoch to start analyzing models at." + "--start-epoch", + type=int, + default=0, + help="The epoch to start analyzing models at." ) @click.option( "--batch-size", + type=int, default=8192, help=f"The size of batches to use to build {STATS_FNAME}.", ) @click.option( "--num-batches", + type=int, default=100, help=f'The number of batches to use to build {STATS_FNAME}. 
Setting it to -1 indicates "use all batches".', ) @click.option( "--max-segments", + type=int, default=15, help=f"The maximum # of segments in the latent space when creating {RANGES_FNAME}.", ) @click.option( "--num-points-to-sample", + type=int, default=8192, help=f"The number of points to sample when creating {RANGES_FNAME}.", ) @click.option( "--num-classes-to-keep", + type=int, default=5, help=f"The number of classes to keep for each segment when creating {RANGES_FNAME}.", ) @@ -158,12 +174,21 @@ help="The number of indexes whose errors to compute during _attach().", ) @click.option( - "--num-curate", default=1, help="The number of models to curate" + "--num-curate", + type=int, + default=1, + help="The number of models to curate" ) # TODO(wz2): why would we want to curate more than one? @click.option( - "--allow-all", is_flag=True, help="Whether to curate within or across parts." + "--allow-all", + is_flag=True, + help="Whether to curate within or across parts." +) +@click.option("--flatten-idx", + type=int, + default=0, + help="TODO(wz2)" ) -@click.option("--flatten-idx", default=0, help="TODO(wz2)") def train( dbgym_cfg, benchmark_name, diff --git a/tune/protox/embedding/train_args.py b/tune/protox/embedding/train_args.py index f4a955f9..21b2917a 100644 --- a/tune/protox/embedding/train_args.py +++ b/tune/protox/embedding/train_args.py @@ -1,16 +1,19 @@ +from pathlib import Path + + class EmbeddingTrainGenericArgs: """Same comment as EmbeddingDatagenGenericArgs""" def __init__( self, - benchmark_name, - workload_name, - scale_factor, - benchmark_config_path, - traindata_path, - seed, - workload_path, - ): + benchmark_name: str, + workload_name: str, + scale_factor: float, + benchmark_config_path: Path, + traindata_path: Path, + seed: int, + workload_path: Path, + ) -> None: self.benchmark_name = benchmark_name self.workload_name = workload_name self.scale_factor = scale_factor @@ -25,12 +28,12 @@ class EmbeddingTrainAllArgs: def __init__( self, - hpo_space_path, - train_max_concurrent, - iterations_per_epoch, - num_samples, - train_size, - ): + hpo_space_path: Path, + train_max_concurrent: int, + iterations_per_epoch: int, + num_samples: int, + train_size: float, + ) -> None: self.hpo_space_path = hpo_space_path self.train_max_concurrent = train_max_concurrent self.iterations_per_epoch = iterations_per_epoch @@ -43,13 +46,13 @@ class EmbeddingAnalyzeArgs: def __init__( self, - start_epoch, - batch_size, - num_batches, - max_segments, - num_points_to_sample, - num_classes_to_keep, - ): + start_epoch: int, + batch_size: int, + num_batches: int, + max_segments: int, + num_points_to_sample: int, + num_classes_to_keep: int, + ) -> None: self.start_epoch = start_epoch self.batch_size = batch_size self.num_batches = num_batches @@ -62,8 +65,15 @@ class EmbeddingSelectArgs: """Same comment as EmbeddingDatagenGenericArgs""" def __init__( - self, recon, latent_dim, bias_sep, idx_limit, num_curate, allow_all, flatten_idx - ): + self, + recon: float, + latent_dim: int, + bias_sep: float, + idx_limit: int, + num_curate: int, + allow_all: bool, + flatten_idx: int + ) -> None: self.recon = recon self.latent_dim = latent_dim self.bias_sep = bias_sep From 0c816123689d59c4f3b63674e6c468f1eec19369 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:49:10 +0000 Subject: [PATCH 15/60] fixed scripts/read_parquet.py --- dependencies/requirements.txt | 2 ++ scripts/read_parquet.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dependencies/requirements.txt 
b/dependencies/requirements.txt index 1be1b2b0..df6723a0 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -76,6 +76,7 @@ oauthlib==3.2.2 opt-einsum==3.3.0 packaging==23.2 pandas==2.2.0 +pandas-stubs==2.2.2.240807 pathspec==0.12.1 pglast==6.2 platformdirs==4.2.0 @@ -118,6 +119,7 @@ tomli==2.0.1 torch==2.0.0 tqdm==4.66.1 triton==2.0.0 +types-pytz==2024.1.0.20240417 typing_extensions==4.9.0 tzdata==2023.4 urllib3==2.2.0 diff --git a/scripts/read_parquet.py b/scripts/read_parquet.py index 161aec35..36b7cd35 100644 --- a/scripts/read_parquet.py +++ b/scripts/read_parquet.py @@ -1,9 +1,10 @@ +from pathlib import Path import sys import pandas as pd -def read_and_print_parquet(file_path): +def read_and_print_parquet(file_path: Path) -> None: # Read the Parquet file into a DataFrame df = pd.read_parquet(file_path) @@ -14,7 +15,7 @@ def read_and_print_parquet(file_path): if __name__ == "__main__": # Specify the path to the Parquet file - parquet_file_path = sys.argv[0] + parquet_file_path = Path(sys.argv[0]) # Call the function to read and print the Parquet file read_and_print_parquet(parquet_file_path) From 8fc0af5f3a93bc8d9d00980a9e4cf3c51fa61899 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:53:33 +0000 Subject: [PATCH 16/60] fixed util/shell.py --- util/shell.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/util/shell.py b/util/shell.py index ab06f4c3..29a03aff 100644 --- a/util/shell.py +++ b/util/shell.py @@ -1,18 +1,16 @@ import logging import os +from pathlib import Path import subprocess +from typing import Optional shell_util_logger = logging.getLogger("shell_util") shell_util_logger.setLevel(logging.INFO) -def subprocess_run(c, cwd=None, check_returncode=True, dry_run=False, verbose=True): +def subprocess_run(c: str, cwd: Optional[Path]=None, check_returncode: bool=True, verbose: bool=True) -> subprocess.Popen[str]: cwd_msg = f"(cwd: {cwd if cwd is not None else os.getcwd()})" - if dry_run: - shell_util_logger.info(f"Dry run {cwd_msg}: {c}") - return - if verbose: shell_util_logger.info(f"Running {cwd_msg}: {c}") @@ -27,6 +25,7 @@ def subprocess_run(c, cwd=None, check_returncode=True, dry_run=False, verbose=Tr ) as proc: while True: loop = proc.poll() is None + assert proc.stdout is not None for line in proc.stdout: if verbose: print(line, end="", flush=True) From b626bb47086e3c8b757a58deb414574240c96428 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 16:55:01 +0000 Subject: [PATCH 17/60] fixed tune/protox/env/space/primitive/index.py --- tune/protox/env/space/primitive/index.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tune/protox/env/space/primitive/index.py b/tune/protox/env/space/primitive/index.py index ae31a486..070bf092 100644 --- a/tune/protox/env/space/primitive/index.py +++ b/tune/protox/env/space/primitive/index.py @@ -7,7 +7,7 @@ class IndexAction(object): IA = TypeVar("IA", bound="IndexAction") index_name_counter = 0 - index_name_map: dict["IndexAction", int] = dict() + index_name_map: dict["IndexAction", str] = dict() def __init__( self, @@ -81,7 +81,7 @@ def sql(self, add: bool, allow_fail: bool = False) -> str: # A given index name (like "index5") maps one-to-one to the function of an # index (i.e. its table, columns, etc.). 
- def get_index_name(self): + def get_index_name(self) -> str: if self not in IndexAction.index_name_map: IndexAction.index_name_map[self] = f"index{IndexAction.index_name_counter}" IndexAction.index_name_counter += 1 From b0fbba73ad84aad9b3e655323e54cdca09cfb087 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:00:17 +0000 Subject: [PATCH 18/60] fixed tune/protox/embedding/vae.py --- dependencies/requirements.txt | 8 ++++---- scripts/mypy.ini | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dependencies/requirements.txt b/dependencies/requirements.txt index df6723a0..a5cf3858 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -58,7 +58,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu11==11.7.99 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu11==8.5.0.96 -nvidia-cudnn-cu12==8.9.2.26 +nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu11==10.9.0.58 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu11==10.2.10.91 @@ -68,7 +68,7 @@ nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu11==11.7.4.91 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu11==2.14.3 -nvidia-nccl-cu12==2.19.3 +nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu11==11.7.91 nvidia-nvtx-cu12==12.1.105 @@ -116,9 +116,9 @@ tensorflow-io-gcs-filesystem==0.36.0 termcolor==2.4.0 threadpoolctl==3.2.0 tomli==2.0.1 -torch==2.0.0 +torch==2.4.0 tqdm==4.66.1 -triton==2.0.0 +triton==3.0.0 types-pytz==2024.1.0.20240417 typing_extensions==4.9.0 tzdata==2023.4 diff --git a/scripts/mypy.ini b/scripts/mypy.ini index b1efbe84..e68d5fa4 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -1,2 +1,3 @@ [mypy] strict = True +ignore_missing_imports = True \ No newline at end of file From 0eeccc5b71468fce4c723c9875ffc996e2d2cac8 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:07:29 +0000 Subject: [PATCH 19/60] fixed misc/utils.py --- dependencies/requirements.txt | 1 + misc/utils.py | 46 +++++++++++++++++------------------ 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/dependencies/requirements.txt b/dependencies/requirements.txt index a5cf3858..208698b1 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -120,6 +120,7 @@ torch==2.4.0 tqdm==4.66.1 triton==3.0.0 types-pytz==2024.1.0.20240417 +types-PyYAML==6.0.12.20240808 typing_extensions==4.9.0 tzdata==2023.4 urllib3==2.2.0 diff --git a/misc/utils.py b/misc/utils.py index 229bc075..d68cc233 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -4,7 +4,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Callable, Tuple, Optional +from typing import IO, Any, Callable, Tuple, Optional import redis import yaml @@ -32,15 +32,15 @@ # Helper functions that both this file and other files use -def get_symlinks_path_from_workspace_path(workspace_path): +def get_symlinks_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "symlinks" -def get_tmp_path_from_workspace_path(workspace_path): +def get_tmp_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "tmp" -def get_runs_path_from_workspace_path(workspace_path): +def get_runs_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "task_runs" @@ -84,7 +84,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float | str) -> str: # Standard names of files/directories. These can refer to either the actual file/directory or a link to the file/directory. 
# Since they can refer to either the actual or the link, they do not have ".link" in them. -traindata_fname = ( +traindata_fname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedding_traindata.parquet" ) default_embedder_dname: Callable[[str, str], str] = ( @@ -110,7 +110,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float | str) -> str: # folder called run_*/dbgym_agent_protox_tune/tuning_steps. However, replay itself generates an output.log file, which goes in # run_*/dbgym_agent_protox_tune/tuning_steps/. The bug was that my replay function was overwriting the output.log file of the # tuning run. By naming all symlinks "*.link", we avoid the possibility of subtle bugs like this happening. -default_traindata_path = ( +default_traindata_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -198,7 +198,7 @@ def __init__(self, dbgym_config_path: Path): # Parse the YAML file. contents: str = dbgym_config_path.read_text() - yaml_config: dict = yaml.safe_load(contents) + yaml_config: dict[str, Any] = yaml.safe_load(contents) # Require dbgym_workspace_path to be absolute. # All future paths should be constructed from dbgym_workspace_path. @@ -208,8 +208,8 @@ def __init__(self, dbgym_config_path: Path): self.path: Path = dbgym_config_path self.cur_path_list: list[str] = ["dbgym"] - self.root_yaml: dict = yaml_config - self.cur_yaml: dict = self.root_yaml + self.root_yaml: dict[str, Any] = yaml_config + self.cur_yaml: dict[str, Any] = self.root_yaml # Set and create paths. self.dbgym_repo_path = Path(os.getcwd()) @@ -244,11 +244,11 @@ def __init__(self, dbgym_config_path: Path): # `append_group()` is used to mark the "codebase path" of an invocation of the CLI. The "codebase path" is # explained further in the documentation. 
- def append_group(self, name) -> None: + def append_group(self, name: str) -> None: self.cur_path_list.append(name) self.cur_yaml = self.cur_yaml.get(name, {}) - def cur_source_path(self, *dirs) -> Path: + def cur_source_path(self, *dirs: str) -> Path: cur_path = self.dbgym_repo_path assert self.cur_path_list[0] == "dbgym" for folder in self.cur_path_list[1:]: @@ -257,7 +257,7 @@ def cur_source_path(self, *dirs) -> Path: cur_path = cur_path / dir return cur_path - def cur_symlinks_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_path(self, *dirs: str, mkdir: bool=False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_symlinks_path / flattened_structure for dir in dirs: @@ -266,7 +266,7 @@ def cur_symlinks_path(self, *dirs, mkdir=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_task_runs_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_path(self, *dirs: str, mkdir: bool=False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_this_run_path / flattened_structure for dir in dirs: @@ -275,27 +275,27 @@ def cur_task_runs_path(self, *dirs, mkdir=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_symlinks_bin_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_bin_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_symlinks_path("bin", *dirs, mkdir=mkdir) - def cur_symlinks_build_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_build_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_symlinks_path("build", *dirs, mkdir=mkdir) - def cur_symlinks_data_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_data_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_symlinks_path("data", *dirs, mkdir=mkdir) - def cur_task_runs_build_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_build_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_task_runs_path("build", *dirs, mkdir=mkdir) - def cur_task_runs_data_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_data_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_task_runs_path("data", *dirs, mkdir=mkdir) - def cur_task_runs_artifacts_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_artifacts_path(self, *dirs: str, mkdir: bool=False) -> Path: return self.cur_task_runs_path("artifacts", *dirs, mkdir=mkdir) def conv_inputpath_to_realabspath( - dbgym_cfg: DBGymConfig, inputpath: os.PathLike + dbgym_cfg: DBGymConfig, inputpath: os.PathLike[str] ) -> Path: """ Convert any user inputted path to a real, absolute path @@ -326,7 +326,7 @@ def conv_inputpath_to_realabspath( return realabspath -def is_base_git_dir(cwd) -> bool: +def is_base_git_dir(cwd: str) -> bool: """ Returns whether we are in the base directory of some git repository """ @@ -391,7 +391,7 @@ def basename_of_path(dpath: Path) -> str: # TODO(phw2): refactor to use Path -def is_child_path(child_path: os.PathLike, parent_dpath: os.PathLike) -> bool: +def is_child_path(child_path: os.PathLike[str], parent_dpath: os.PathLike[str]) -> bool: """ Checks whether child_path refers to a file/dir/link that is a child of the dir referred to by parent_dpath If the two paths are equal, this function returns FALSE @@ -405,7 +405,7 @@ def is_child_path(child_path: os.PathLike, parent_dpath: os.PathLike) -> bool: ) -def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode="r"): +def open_and_save(dbgym_cfg: DBGymConfig, 
open_fpath: Path, mode: str="r") -> IO[Any]: """ Open a file and "save" it to [workspace]/task_runs/run_*/. It takes in a str | Path to match the interface of open(). From 921ac5d535d3f755201213811a1bb9f9c082f38a Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:15:09 +0000 Subject: [PATCH 20/60] fixed tune/protox/agent/hpo.py --- tune/protox/agent/hpo.py | 10 +++++----- tune/protox/embedding/utils.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tune/protox/agent/hpo.py b/tune/protox/agent/hpo.py index 8fdf860c..98736ab3 100644 --- a/tune/protox/agent/hpo.py +++ b/tune/protox/agent/hpo.py @@ -559,7 +559,7 @@ def setup(self, hpo_params: dict[str, Any]) -> None: # Attach mythril directory to the search path. sys.path.append(os.path.expanduser(self.dbgym_cfg.dbgym_repo_path)) - torch.set_default_dtype(torch.float32) + torch.set_default_dtype(torch.float32) # type: ignore[no-untyped-call] seed = ( hpo_params["seed"] if hpo_params["seed"] != -1 @@ -648,7 +648,7 @@ def step(self) -> dict[Any, Any]: def cleanup(self) -> None: self.logger.flush() - self.env.close() + self.env.close() # type: ignore[no-untyped-call] if Path(self.signal).exists(): os.remove(self.signal) @@ -663,7 +663,7 @@ def create_tune_opt_class(dbgym_cfg_param: DBGymConfig) -> Type[Trainable]: global global_dbgym_cfg global_dbgym_cfg = dbgym_cfg_param - class TuneOpt(Trainable): # type: ignore + class TuneOpt(Trainable): dbgym_cfg = global_dbgym_cfg def setup(self, hpo_params: dict[str, Any]) -> None: @@ -750,7 +750,7 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: ) # Scheduler. - scheduler = FIFOScheduler() + scheduler = FIFOScheduler() # type: ignore[no-untyped-call] # Search. search = BasicVariantGenerator(max_concurrent=hpo_args.max_concurrent) @@ -773,7 +773,7 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: sync_config=SyncConfig(), verbose=2, log_to_file=True, - storage_path=dbgym_cfg.cur_task_runs_path("hpo_ray_results", mkdir=True), + storage_path=str(dbgym_cfg.cur_task_runs_path("hpo_ray_results", mkdir=True)), ) tuner = ray.tune.Tuner( diff --git a/tune/protox/embedding/utils.py b/tune/protox/embedding/utils.py index a631c24f..0e369158 100644 --- a/tune/protox/embedding/utils.py +++ b/tune/protox/embedding/utils.py @@ -1,6 +1,6 @@ from typing import Any -from hyperopt import hp # type: ignore +from hyperopt import hp def f_unpack_dict(dct: dict[str, Any]) -> dict[str, Any]: From a6f22401b919f1da7f437f76b436829060dcd240 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:21:33 +0000 Subject: [PATCH 21/60] fixed tune/protox/agent/replay.py --- dependencies/requirements.txt | 3 ++- tune/protox/agent/agent_env.py | 4 +++- tune/protox/agent/replay.py | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dependencies/requirements.txt b/dependencies/requirements.txt index 208698b1..2bf61344 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -26,7 +26,7 @@ google-auth-oauthlib==1.0.0 google-pasta==0.2.0 greenlet==3.0.3 grpcio==1.60.0 -gymnasium==0.28.1 +gymnasium==0.29.1 h5py==3.10.0 hyperopt==0.2.7 idna==3.6 @@ -119,6 +119,7 @@ tomli==2.0.1 torch==2.4.0 tqdm==4.66.1 triton==3.0.0 +types-python-dateutil==2.9.0.20240821 types-pytz==2024.1.0.20240417 types-PyYAML==6.0.12.20240808 typing_extensions==4.9.0 diff --git a/tune/protox/agent/agent_env.py b/tune/protox/agent/agent_env.py index b5af657b..98e311b9 100644 --- a/tune/protox/agent/agent_env.py +++ 
b/tune/protox/agent/agent_env.py @@ -6,8 +6,10 @@ import numpy as np from numpy.typing import NDArray +from tune.protox.env.pg_env import PostgresEnv -class AgentEnv(gym.Wrapper[Any, Any, Any, Any]): + +class AgentEnv(gym.Wrapper[PostgresEnv, Any, Any, Any]): def __init__(self, env: gym.Env[Any, Any]): super().__init__(env) self.class_attributes = dict(inspect.getmembers(self.__class__)) diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index cca69911..83f0a89b 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -11,7 +11,7 @@ import logging import pickle from pathlib import Path -from typing import Any, Optional, Set +from typing import Any, Optional, Set, cast import click import pandas as pd @@ -233,8 +233,8 @@ def _is_tuning_step_line(line: str) -> bool: _, _, agent_env, _, _ = build_trial( dbgym_cfg, TuningMode.REPLAY, hpo_params["seed"], hpo_params ) - pg_env: PostgresEnv = agent_env.unwrapped - action_space: HolonSpace = pg_env.action_space + pg_env: PostgresEnv = cast(PostgresEnv, agent_env.unwrapped) + action_space: HolonSpace = cast(HolonSpace, pg_env.action_space) # Reset things. if not replay_args.simulated: From 64ddbd935b357208c9551964c1fa07ee6c18f8ac Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:27:10 +0000 Subject: [PATCH 22/60] fixed tune/protox/agent/build_trial.py --- tune/protox/agent/agent_env.py | 4 +--- tune/protox/agent/build_trial.py | 11 +++++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tune/protox/agent/agent_env.py b/tune/protox/agent/agent_env.py index 98e311b9..b5af657b 100644 --- a/tune/protox/agent/agent_env.py +++ b/tune/protox/agent/agent_env.py @@ -6,10 +6,8 @@ import numpy as np from numpy.typing import NDArray -from tune.protox.env.pg_env import PostgresEnv - -class AgentEnv(gym.Wrapper[PostgresEnv, Any, Any, Any]): +class AgentEnv(gym.Wrapper[Any, Any, Any, Any]): def __init__(self, env: gym.Env[Any, Any]): super().__init__(env) self.class_attributes = dict(inspect.getmembers(self.__class__)) diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index 4bd23ae5..dde38718 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -1,7 +1,5 @@ import glob import json -import os -import shutil import socket import xml.etree.ElementTree as ET from pathlib import Path @@ -11,8 +9,9 @@ import numpy as np import torch from gymnasium.wrappers import FlattenObservation # type: ignore -from gymnasium.wrappers import NormalizeObservation, NormalizeReward +from gymnasium.wrappers import NormalizeObservation, NormalizeReward # type: ignore[attr-defined] from torch import nn +from torch.optim import Adam # type: ignore[attr-defined] from misc.utils import ( DBGymConfig, @@ -85,7 +84,7 @@ def _get_signal(signal_folder: Union[str, Path]) -> Tuple[int, str]: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) continue - with open(f"{signal_folder}/{port}.signal", "w") as f: # type: IO[Any] + with open(f"{signal_folder}/{port}.signal", "w") as f: f.write(str(port)) f.close() @@ -434,7 +433,7 @@ def _build_agent( policy_weight_adjustment=hpo_params["policy_weight_adjustment"], ) - actor_optimizer = torch.optim.Adam( + actor_optimizer = Adam( actor.parameters(), lr=hpo_params["learning_rate"] ) @@ -462,7 +461,7 @@ def _build_agent( action_dim=critic_action_dim, ) - critic_optimizer = torch.optim.Adam( + critic_optimizer = Adam( critic.parameters(), lr=hpo_params["learning_rate"] * hpo_params["critic_lr_scale"], ) From 
d9d69966998c9ae2f68e676a1cb4e65aa2d29a21 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:32:25 +0000 Subject: [PATCH 23/60] fixed tune/protox/embedding/cli.py --- tune/protox/embedding/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/embedding/cli.py b/tune/protox/embedding/cli.py index 264afed9..0e8829b9 100644 --- a/tune/protox/embedding/cli.py +++ b/tune/protox/embedding/cli.py @@ -7,7 +7,7 @@ @click.group("embedding") @click.pass_obj -def embedding_group(dbgym_cfg: DBGymConfig): +def embedding_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("embedding") From 70c2d198a85798f40b753a72861d59ec29a330f0 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 17:36:21 +0000 Subject: [PATCH 24/60] fixed tune/protox/embedding/train.py --- dbms/postgres/cli.py | 4 +- tune/protox/agent/hpo.py | 2 +- tune/protox/agent/tune.py | 4 +- tune/protox/embedding/datagen.py | 18 ++++---- tune/protox/embedding/train.py | 70 +++++++++++++++++--------------- 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/dbms/postgres/cli.py b/dbms/postgres/cli.py index 140f7e7c..6dd6e40e 100644 --- a/dbms/postgres/cli.py +++ b/dbms/postgres/cli.py @@ -99,9 +99,9 @@ def postgres_dbdata( dbdata_parent_dpath: Path, ): # Set args to defaults programmatically (do this before doing anything else in the function) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) diff --git a/tune/protox/agent/hpo.py b/tune/protox/agent/hpo.py index 98736ab3..251d526e 100644 --- a/tune/protox/agent/hpo.py +++ b/tune/protox/agent/hpo.py @@ -551,7 +551,7 @@ def __init__( ), "If we're doing HPO, we will create multiple TuneTrial() objects. We thus need to differentiate them somehow." else: assert ( - ray_trial_id == None + ray_trial_id is None ), "If we're not doing HPO, we (currently) will create only one TuneTrial() object. For clarity, we set ray_trial_id to None since ray_trial_id should not be used in this case." self.ray_trial_id = ray_trial_id diff --git a/tune/protox/agent/tune.py b/tune/protox/agent/tune.py index 2ec6045b..c9a3467c 100644 --- a/tune/protox/agent/tune.py +++ b/tune/protox/agent/tune.py @@ -89,7 +89,7 @@ def tune( """IMPORTANT: The "tune" here is the one in "tune a DBMS". This is *different* from the "tune" in ray.tune.TuneConfig, which means to "tune hyperparameters".""" # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if hpoed_agent_params_path == None: + if hpoed_agent_params_path is None: hpoed_agent_params_path = default_hpoed_agent_params_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) @@ -120,7 +120,7 @@ def tune( ) # Set defaults that depend on hpo_params - if tune_duration_during_tune == None: + if tune_duration_during_tune is None: tune_duration_during_tune = hpo_params["tune_duration"][str(TuningMode.HPO)] # Set the hpo_params that are allowed to differ between HPO, tuning, and replay. 
diff --git a/tune/protox/embedding/datagen.py b/tune/protox/embedding/datagen.py index 53defc2b..c415d9fc 100644 --- a/tune/protox/embedding/datagen.py +++ b/tune/protox/embedding/datagen.py @@ -201,25 +201,25 @@ def datagen( # TODO(phw2): figure out whether different scale factors use the same config # TODO(phw2): figure out what parts of the config should be taken out (like stuff about tables) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if workload_path == None: + if workload_path is None: workload_path = default_workload_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if pristine_dbdata_snapshot_path == None: + if pristine_dbdata_snapshot_path is None: pristine_dbdata_snapshot_path = default_pristine_dbdata_snapshot_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, scale_factor ) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) - if max_concurrent == None: + if max_concurrent is None: max_concurrent = os.cpu_count() - if seed == None: + if seed is None: seed = random.randint(0, 1e8) # Convert all input paths to absolute paths @@ -246,10 +246,10 @@ def datagen( assert False # Process the "data structure" args - leading_col_tbls = [] if leading_col_tbls == None else leading_col_tbls.split(",") + leading_col_tbls = [] if leading_col_tbls is None else leading_col_tbls.split(",") # I chose to only use the "," delimiter in override_sample_limits_str, so the dictionary is encoded as [key],[value],[key],[value] # I felt this was better than introducing a new delimiter which might conflict with the name of a table - if override_sample_limits == None: + if override_sample_limits is None: override_sample_limits = dict() else: override_sample_limits_str = override_sample_limits diff --git a/tune/protox/embedding/train.py b/tune/protox/embedding/train.py index 7aed79d4..67609f56 100644 --- a/tune/protox/embedding/train.py +++ b/tune/protox/embedding/train.py @@ -12,6 +12,7 @@ DEFAULT_HPO_SPACE_PATH, WORKLOAD_NAME_PLACEHOLDER, WORKSPACE_PATH_PLACEHOLDER, + DBGymConfig, conv_inputpath_to_realabspath, default_benchmark_config_path, default_traindata_path, @@ -40,7 +41,10 @@ @click.pass_obj # generic args -@click.argument("benchmark-name", type=str) +@click.argument( + "benchmark-name", + type=str +) @click.option( "--seed-start", type=int, @@ -190,34 +194,34 @@ help="TODO(wz2)" ) def train( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - benchmark_config_path, - traindata_path, - seed, - hpo_space_path, - train_max_concurrent, - iterations_per_epoch, - num_samples, - train_size, - start_epoch, - batch_size, - num_batches, - max_segments, - num_points_to_sample, - num_classes_to_keep, - recon, - latent_dim, - bias_sep, - idx_limit, - num_curate, - allow_all, - flatten_idx, -): + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + benchmark_config_path: Optional[Path], + traindata_path: Optional[Path], + seed: Optional[int], + hpo_space_path: Path, + train_max_concurrent: int, + iterations_per_epoch: int, + num_samples: int, + train_size: int, + start_epoch: int, + 
batch_size: int, + num_batches: int, + max_segments: int, + num_points_to_sample: int, + num_classes_to_keep: int, + recon: float, + latent_dim: int, + bias_sep: float, + idx_limit: int, + num_curate: int, + allow_all: bool, + flatten_idx: int, +) -> None: """ Trains embeddings with num_samples samples of the hyperparameter space. Analyzes the accuracy of all epochs of all hyperparameter space samples. @@ -225,16 +229,16 @@ def train( """ # set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if traindata_path == None: + if traindata_path is None: traindata_path = default_traindata_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) # TODO(phw2): figure out whether different scale factors use the same config # TODO(phw2): figure out what parts of the config should be taken out (like stuff about tables) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if seed == None: - seed = random.randint(0, 1e8) + if seed is None: + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths benchmark_config_path = conv_inputpath_to_realabspath( From 5f52029a1a9a7de3edb2297d4a95843ed5312ec7 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 18:16:00 +0000 Subject: [PATCH 25/60] now ignoring errors in embedding/ --- scripts/mypy.ini | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/mypy.ini b/scripts/mypy.ini index e68d5fa4..e4c360df 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -1,3 +1,6 @@ [mypy] strict = True -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True + +[mypy-tune.protox.embedding.*] +ignore_errors = True \ No newline at end of file From 62280888b391e00c488c2b489859a932da2381ab Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:06:22 +0000 Subject: [PATCH 26/60] fixed util/pg.py --- util/pg.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/util/pg.py b/util/pg.py index 8c5f1e78..9e08f07e 100644 --- a/util/pg.py +++ b/util/pg.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List +from typing import Any, List import pglast import psycopg @@ -16,7 +16,7 @@ SHARED_PRELOAD_LIBRARIES = "boot,pg_hint_plan,pg_prewarm" -def conn_execute(conn: Connection, sql: str) -> CursorResult: +def conn_execute(conn: Connection, sql: str) -> CursorResult[Any]: return conn.execute(text(sql)) @@ -29,8 +29,9 @@ def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> List[str]: if len(line.strip()) == 0: continue lines.append(line) - queries = "".join(lines) - return pglast.split(queries) + queries_str = "".join(lines) + queries: list[str] = pglast.split(queries_str) + return queries def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) -> None: @@ -40,7 +41,7 @@ def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) - # The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres # at the same time. 
In this situation, they need to have different ports -def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> str: +def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool=True) -> str: connstr_suffix = f"{DBGYM_POSTGRES_USER}:{DBGYM_POSTGRES_PASS}@localhost:{pgport}/{DBGYM_POSTGRES_DBNAME}" # use_psycopg means whether or not we use the psycopg.connect() function # counterintuively, you *don't* need psycopg in the connection string if you *are* @@ -49,12 +50,14 @@ def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> str: return connstr_prefix + "://" + connstr_suffix -def create_conn(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> Connection: +def create_conn(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool=True) -> Connection: connstr = get_connstr(use_psycopg=use_psycopg, pgport=pgport) if use_psycopg: - return psycopg.connect(connstr, autocommit=True, prepare_threshold=None) + psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) + engine = create_engine(connstr, creator=lambda : psycopg_conn) + return engine.connect() else: - engine: Engine = create_engine( + engine = create_engine( connstr, execution_options={"isolation_level": "AUTOCOMMIT"}, ) From 96338bc5b5883886b0ea9463e291fc53070bc979 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:10:48 +0000 Subject: [PATCH 27/60] fixed tune/protox/env/mqo/mqo_wrapper.py --- tune/protox/env/mqo/mqo_wrapper.py | 6 +++++- tune/protox/env/types.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tune/protox/env/mqo/mqo_wrapper.py b/tune/protox/env/mqo/mqo_wrapper.py index 84baa36f..acd5456f 100644 --- a/tune/protox/env/mqo/mqo_wrapper.py +++ b/tune/protox/env/mqo/mqo_wrapper.py @@ -163,7 +163,7 @@ def __init__( self.logger = logger def _update_best_observed( - self, query_metric_data: dict[str, BestQueryRun], force_overwrite=False + self, query_metric_data: dict[str, BestQueryRun], force_overwrite: bool=False ) -> None: if query_metric_data is not None: for qid, best_run in query_metric_data.items(): @@ -176,6 +176,7 @@ def _update_best_observed( None, ) if self.logger: + assert best_run.runtime is not None self.logger.get_logger(__name__).debug( f"[best_observe] {qid}: {best_run.runtime/1e6} (force: {force_overwrite})" ) @@ -307,6 +308,7 @@ def transmute( ) # Execute. + assert self.logger is not None self.logger.get_logger(__name__).info("MQOWrapper called step_execute()") success, info = self.unwrapped.step_execute(success, runs, info) if info["query_metric_data"]: @@ -319,6 +321,7 @@ def transmute( with torch.no_grad(): # Pass the mutilated action back through. assert isinstance(self.action_space, HolonSpace) + assert info["actions_info"] is not None info["actions_info"][ "best_observed_holon_action" ] = best_observed_holon_action @@ -412,6 +415,7 @@ def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, EnvInfoDict]: # type: # Update the reward baseline. if self.unwrapped.reward_utility: + assert self.unwrapped.baseline_metric self.unwrapped.reward_utility.set_relative_baseline( self.unwrapped.baseline_metric, prev_result=metric, diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 6b7758f7..b61aaa9b 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -190,9 +190,9 @@ class EnvInfoDict(TypedDict, total=False): attempted_changes: Tuple[list[str], list[str]] # Metric of this step. - metric: float + metric: Optional[float] # Reward of this step. 
- reward: float + reward: Optional[float] # Whether any queries timed out or the workload as a whole timed out. did_anything_time_out: bool # Query metric data. From 6996e2c26e50e6eb1f4fc06c73bf36a34b6397bd Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:16:15 +0000 Subject: [PATCH 28/60] fixed tune/protox/env/pg_env.py --- tune/protox/env/pg_env.py | 9 +++++++-- tune/protox/env/types.py | 2 +- tune/protox/env/util/pg_conn.py | 10 +++++----- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tune/protox/env/pg_env.py b/tune/protox/env/pg_env.py index 3e267d53..fe6c5ba1 100644 --- a/tune/protox/env/pg_env.py +++ b/tune/protox/env/pg_env.py @@ -14,6 +14,7 @@ from tune.protox.env.space.state.space import StateSpace from tune.protox.env.space.utils import fetch_server_indexes, fetch_server_knobs from tune.protox.env.types import ( + ActionsInfo, EnvInfoDict, HolonAction, HolonStateContainer, @@ -255,6 +256,7 @@ def step_execute( assert isinstance(self.observation_space, StateSpace) assert isinstance(self.action_space, HolonSpace) # Evaluate the benchmark. + assert self.logger is not None self.logger.get_logger(__name__).info( f"\n\nfetch_server_knobs(): {fetch_server_knobs(self.pg_conn.conn(), self.action_space.get_knob_space().tables, self.action_space.get_knob_space().knobs, self.workload.queries)}\n\n" ) @@ -302,9 +304,10 @@ def step_execute( "query_metric_data": query_metric_data, "reward": reward, "results_dpath": results_dpath, - "actions_info": { + "actions_info": ActionsInfo({ "all_holon_action_variations": all_holon_action_variations, - }, + "best_observed_holon_action": None + }), } ) ) @@ -328,6 +331,7 @@ def step_post_execute( if not soft: if not self.workload.oltp_workload: # Update the workload metric timeout if we've succeeded. + assert info["metric"] is not None self.workload.set_workload_timeout(info["metric"]) # Get the current view of the state container. 
@@ -351,6 +355,7 @@ def step_post_execute( if not soft: self.current_step = self.current_step + 1 self.current_state = next_state + assert info["reward"] is not None return ( self.current_state, info["reward"], diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index b61aaa9b..f6821ed4 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -168,7 +168,7 @@ class QuerySpec(TypedDict, total=False): class ActionsInfo(TypedDict): all_holon_action_variations: list[Tuple[str, HolonAction]] - best_observed_holon_action: HolonAction + best_observed_holon_action: Optional[HolonAction] class EnvInfoDict(TypedDict, total=False): diff --git a/tune/protox/env/util/pg_conn.py b/tune/protox/env/util/pg_conn.py index 94da1ad3..77616ae9 100644 --- a/tune/protox/env/util/pg_conn.py +++ b/tune/protox/env/util/pg_conn.py @@ -71,7 +71,7 @@ def __init__( self._conn: Optional[psycopg.Connection[Any]] = None - def get_connstr(self): + def get_connstr(self) -> str: return f"host=localhost port={self.pgport} user={DBGYM_POSTGRES_USER} password={DBGYM_POSTGRES_PASS} dbname={DBGYM_POSTGRES_DBNAME}" def conn(self) -> psycopg.Connection[Any]: @@ -358,11 +358,11 @@ def cancel_fn(conn_str: str) -> None: self.disconnect() return 0, None - def restore_pristine_snapshot(self) -> None: - self._restore_snapshot(self.pristine_dbdata_snapshot_fpath) + def restore_pristine_snapshot(self) -> bool: + return self._restore_snapshot(self.pristine_dbdata_snapshot_fpath) - def restore_checkpointed_snapshot(self): - self._restore_snapshot(self.checkpoint_dbdata_snapshot_fpath) + def restore_checkpointed_snapshot(self) -> bool: + return self._restore_snapshot(self.checkpoint_dbdata_snapshot_fpath) @time_record("restore") def _restore_snapshot( From 20d83bf3ae8d4524a5ee186de5fd43256b6090a9 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:17:11 +0000 Subject: [PATCH 29/60] fixed tune/protox/agent/replay.py --- tune/protox/agent/replay.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index 83f0a89b..c1284569 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -275,6 +275,7 @@ def _execute_workload_wrapper(actions_info: ActionsInfo) -> tuple[int, int, bool # will not have had a chance to run at all. Based on the behavior of `_mutilate_action_with_metrics()`, we select # an arbitrary variation fo the queries that have not executed at all. 
best_observed_holon_action = actions_info["best_observed_holon_action"] + assert best_observed_holon_action is not None actions = [best_observed_holon_action] variation_names = ["BestObserved"] From 8c59fded8e7db1cb741ab7a9da8f5fdea084abf5 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:28:09 +0000 Subject: [PATCH 30/60] fixed tune/protox/tests/test_index_space.py --- tune/protox/env/workload.py | 7 ++++--- tune/protox/tests/test_index_space.py | 21 +++++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tune/protox/env/workload.py b/tune/protox/env/workload.py index f56b931b..3e64448b 100644 --- a/tune/protox/env/workload.py +++ b/tune/protox/env/workload.py @@ -57,7 +57,7 @@ def _open_for_reading( # why we assert here # I still chose to make mode an argument just to make the interface identical to open()/open_and_save() assert mode == "r" - if self.dbgym_cfg != None: + if self.dbgym_cfg is not None: return open_and_save(self.dbgym_cfg, path) else: return open(path) @@ -162,7 +162,7 @@ def _crunch( ) if do_tbl_include_subsets_prune: - self.tbl_include_subsets = {} + self.tbl_include_subsets = TableAttrAccessSetsMap({}) # First prune any "fully enclosed". for tbl, attrsets in tbl_include_subsets.items(): self.tbl_include_subsets[tbl] = set( @@ -217,7 +217,8 @@ def _crunch( def __init__( self, - dbgym_cfg: DBGymConfig, + # dbgym_cfg is only optional so we can set it to None for unittests. Don't set it to None during normal operation. + dbgym_cfg: Optional[DBGymConfig], tables: list[str], attributes: TableAttrListMap, query_spec: QuerySpec, diff --git a/tune/protox/tests/test_index_space.py b/tune/protox/tests/test_index_space.py index 02225649..977e6764 100644 --- a/tune/protox/tests/test_index_space.py +++ b/tune/protox/tests/test_index_space.py @@ -6,18 +6,19 @@ from tune.protox.env.space.primitive_space import IndexSpace from tune.protox.env.space.utils import check_subspace +from tune.protox.env.types import IndexSpaceRawSample from tune.protox.env.workload import Workload class IndexSpaceTests(unittest.TestCase): @staticmethod def load( - config_path=Path( + config_path: Path=Path( "tune/protox/tests/unittest_benchmark_configs/unittest_tpch.yaml" ).resolve(), - aux_type=True, - aux_include=True, - ): + aux_type: bool=True, + aux_include: bool=True, + ) -> tuple[Workload, IndexSpace]: # don't call open_and_save() because this is a unittest with open(config_path, "r") as f: benchmark_config = yaml.safe_load(f) @@ -51,7 +52,7 @@ def load( ) return w, i - def test_null_action(self): + def test_null_action(self) -> None: w, i = IndexSpaceTests.load() null_action = i.null_action() self.assertTrue(check_subspace(i, null_action)) @@ -60,19 +61,19 @@ def test_null_action(self): null_action = i.null_action() self.assertTrue(check_subspace(i, null_action)) - def test_sample(self): + def test_sample(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): self.assertTrue(check_subspace(i, i.sample())) - def test_sample_table(self): + def test_sample_table(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): mask = {"table_idx": 2} ia = i.to_action(i.sample(mask)) self.assertEqual(ia.tbl_name, "lineitem") - def test_sample_table_col(self): + def test_sample_table_col(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): mask = {"table_idx": 2, "col_idx": 1} @@ -80,12 +81,12 @@ def test_sample_table_col(self): 
self.assertEqual(ia.tbl_name, "lineitem") self.assertEqual(ia.columns[0], "l_partkey") - def test_neighborhood(self): + def test_neighborhood(self) -> None: w, i = IndexSpaceTests.load(aux_type=True, aux_include=True) _, isa = IndexSpaceTests.load(aux_type=False, aux_include=False) act = isa.sample(mask={"table_idx": 2, "col_idx": 1}) - act = (0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)) + act = IndexSpaceRawSample(tuple([0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)])) self.assertTrue(check_subspace(i, act)) neighbors = i.policy.structural_neighbors(act) From 4d1d118101649255f7c58cf956401ab2ba5b8a35 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:32:46 +0000 Subject: [PATCH 31/60] fixed tune/protox/tests/test_workload.py --- tune/protox/tests/test_workload.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tune/protox/tests/test_workload.py b/tune/protox/tests/test_workload.py index fb46fea3..79c6c45b 100644 --- a/tune/protox/tests/test_workload.py +++ b/tune/protox/tests/test_workload.py @@ -1,16 +1,18 @@ import json +from typing import Any, Tuple import unittest from pathlib import Path import yaml from tune.protox.env.space.primitive_space import IndexSpace +from tune.protox.env.types import TableAttrAccessSetsMap, TableColTuple from tune.protox.env.workload import Workload class WorkloadTests(unittest.TestCase): @staticmethod - def load(config_file: str, workload_path: Path): + def load(config_file: str, workload_path: Path) -> tuple[Workload, IndexSpace]: # don't call open_and_save() because this is a unittest with open(config_file, "r") as f: benchmark_config = yaml.safe_load(f) @@ -37,19 +39,19 @@ def load(config_file: str, workload_path: Path): seed=0, rel_metadata=w.column_usages(), attributes_overwrite=w.column_usages(), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), index_space_aux_type=True, index_space_aux_include=True, deterministic_policy=True, ) return w, i - def diff_classmapping(self, ref, target): + def diff_classmapping(self, ref: dict[TableColTuple, int], target: dict[TableColTuple, int]) -> None: for k, v in ref.items(): self.assertTrue(k in target, msg=f"{k} is missing.") self.assertTrue(v == target[k]) - def test_tpch(self): + def test_tpch(self) -> None: with open("tune/protox/tests/unittest_ref_models/ref_tpch_model.txt", "r") as f: ref = json.load(f)["class_mapping"] ref = {(v["relname"], v["ord_column"]): int(k) for k, v in ref.items()} @@ -60,7 +62,7 @@ def test_tpch(self): ) self.assertEqual(i.class_mapping, ref) - def test_job(self): + def test_job(self) -> None: # don't call open_and_save() because this is a unittest with open( "tune/protox/tests/unittest_ref_models/ref_job_full_model.txt", "r" @@ -74,7 +76,7 @@ def test_job(self): ) self.assertEqual(i.class_mapping, ref) - def test_dsb(self): + def test_dsb(self) -> None: # don't call open_and_save() because this is a unittest with open("tune/protox/tests/unittest_ref_models/ref_dsb_model.txt", "r") as f: ref = json.load(f)["class_mapping"] @@ -86,7 +88,7 @@ def test_dsb(self): ) self.diff_classmapping(ref, i.class_mapping) - def test_tpcc(self): + def test_tpcc(self) -> None: # don't call open_and_save() because this is a unittest with open("tune/protox/tests/unittest_ref_models/ref_tpcc_model.txt", "r") as f: ref = json.load(f)["class_mapping"] From 083b09e38da457405415a45929647e512b251ab9 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:34:23 +0000 Subject: [PATCH 32/60] fixed 
tune/protox/agent/wolp/policies.py --- tune/protox/agent/wolp/policies.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tune/protox/agent/wolp/policies.py b/tune/protox/agent/wolp/policies.py index c4294f0d..0d1f2a27 100644 --- a/tune/protox/agent/wolp/policies.py +++ b/tune/protox/agent/wolp/policies.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from gymnasium import spaces from numpy.typing import NDArray -from torch.optim import Optimizer +from torch.optim import Optimizer # type: ignore[attr-defined] from tune.protox.agent.buffers import ReplayBufferSamples from tune.protox.agent.noise import ActionNoise @@ -244,7 +244,7 @@ def train_critic( self.critic_optimizer.zero_grad() assert not th.isnan(critic_loss).any() critic_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore + th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True) self.critic.check_grad() self.critic_optimizer.step() return critic_loss @@ -282,7 +282,7 @@ def train_actor(self, replay_data: ReplayBufferSamples) -> Any: self.actor_optimizer.zero_grad() assert not th.isnan(actor_loss).any() actor_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore + th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True) self.actor.check_grad() self.actor_optimizer.step() return actor_loss From 45d93afee9ddfe8c46e5f65b9b49469de3c57669 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:37:09 +0000 Subject: [PATCH 33/60] fixed tune/protox/env/space/state/structure.py --- tune/protox/env/space/state/structure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/env/space/state/structure.py b/tune/protox/env/space/state/structure.py index 04dbffdd..af29c1cf 100644 --- a/tune/protox/env/space/state/structure.py +++ b/tune/protox/env/space/state/structure.py @@ -116,7 +116,7 @@ def construct_offline( else: index_state = np.zeros(index_space.critic_dim(), dtype=np.float32) - state = {} + state: dict[str, Any] = {} if knob_state is not None: state["knobs"] = knob_state if query_state is not None: From 3eae78f6a74918acb433b67a04bf83285a5d751d Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:44:03 +0000 Subject: [PATCH 34/60] fixed tune/protox/env/workload.py --- tune/protox/env/types.py | 10 +++++----- tune/protox/env/workload.py | 39 ++++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index f6821ed4..ec7d22f6 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -153,12 +153,12 @@ class TargetResetConfig(TypedDict, total=False): class QuerySpec(TypedDict, total=False): benchbase: bool oltp_workload: bool - query_transactional: Union[str, Path] - query_directory: Union[str, Path] - query_order: Union[str, Path] + query_transactional: Path + query_directory: Path + query_order: Path - execute_query_directory: Union[str, Path] - execute_query_order: Union[str, Path] + execute_query_directory: Path + execute_query_order: Path tbl_include_subsets_prune: bool tbl_fold_subsets: bool diff --git a/tune/protox/env/workload.py b/tune/protox/env/workload.py index 3e64448b..5ec72f7e 100644 --- a/tune/protox/env/workload.py +++ b/tune/protox/env/workload.py @@ -5,10 +5,10 
@@ import tempfile import time from pathlib import Path -from typing import Any, Optional, Tuple, Union, cast +from typing import IO, Any, Optional, Tuple, Union, cast import numpy as np -import pglast # type: ignore +import pglast from plumbum import local from misc.utils import DBGymConfig, open_and_save @@ -50,13 +50,10 @@ class Workload(object): # However, when creating a Workload object for unittesting, we just want to call open() def _open_for_reading( self, - path, - mode="r", - ): - # when opening for writing we always use open() so we don't need this function, which is - # why we assert here - # I still chose to make mode an argument just to make the interface identical to open()/open_and_save() - assert mode == "r" + path: Path, + ) -> IO[Any]: + # When opening for writing we always use open() so we don't need this function, which is + # why hardcode the mode as "r". if self.dbgym_cfg is not None: return open_and_save(self.dbgym_cfg, path) else: @@ -93,7 +90,7 @@ def _crunch( self.order.append(stem) self.queries_mix[stem] = ratio - with self._open_for_reading(sql_file, "r") as q: + with self._open_for_reading(sql_file) as q: sql = q.read() assert not sql.startswith("/*") @@ -256,7 +253,7 @@ def __init__( sqls = [] order_or_txn_fname = "txn.txt" if self.oltp_workload else "order.txt" workload_order_or_txn_fpath = self.workload_path / order_or_txn_fname - with self._open_for_reading(workload_order_or_txn_fpath, "r") as f: + with self._open_for_reading(workload_order_or_txn_fpath) as f: lines = f.read().splitlines() sqls = [ ( @@ -269,7 +266,7 @@ def __init__( # TODO(phw2): pass "query_transactional" somewhere other than query_spec, just like "query_order" is if "query_transactional" in query_spec: - with self._open_for_reading(query_spec["query_transactional"], "r") as f: + with self._open_for_reading(query_spec["query_transactional"]) as f: lines = f.read().splitlines() splits = [line.split(",") for line in lines] sqls = [ @@ -287,7 +284,7 @@ def __init__( # TODO(phw2): pass "execute_query_order" somewhere other than query_spec, just like "query_order" is if "execute_query_order" in query_spec: - with open_and_save(dbgym_cfg, query_spec["execute_query_order"], "r") as f: + with self._open_for_reading(query_spec["execute_query_order"]) as f: lines = f.read().splitlines() sqls = [ ( @@ -337,7 +334,12 @@ def max_indexable(self) -> int: def compute_total_workload_runtime( qid_runtime_data: dict[str, BestQueryRun] ) -> float: - return sum(best_run.runtime for best_run in qid_runtime_data.values()) / 1.0e6 + total_runtime: float = 0.0 + for best_run in qid_runtime_data.values(): + assert best_run.runtime is not None + total_runtime += best_run.runtime + total_runtime /= 1.0e6 + return total_runtime @time_record("execute") def execute_workload( @@ -345,13 +347,13 @@ def execute_workload( pg_conn: PostgresConn, actions: list[HolonAction] = [], variation_names: list[str] = [], - results_dpath: Optional[Union[str, Path]] = None, + results_dpath: Optional[Path] = None, observation_space: Optional[StateSpace] = None, action_space: Optional[HolonSpace] = None, reset_metrics: Optional[dict[str, BestQueryRun]] = None, override_workload_timeout: Optional[float] = None, query_timeout: Optional[int] = None, - workload_qdir: Optional[Tuple[Union[str, Path], Union[str, Path]]] = None, + workload_qdir: Optional[tuple[Path, Path]] = None, blocklist: list[str] = [], first: bool = False, ) -> Tuple[int, bool, dict[str, Any]]: @@ -391,7 +393,7 @@ def execute_workload( if workload_qdir is not None and 
workload_qdir[0] is not None: # Load actual queries to execute. workload_dir, workload_qlist = workload_qdir - with self._open_for_reading(workload_qlist, "r") as f: + with self._open_for_reading(workload_qlist) as f: psql_order = [ (f"Q{i+1}", Path(workload_dir) / l.strip()) for i, l in enumerate(f.readlines()) @@ -401,7 +403,7 @@ def execute_workload( actual_sql_files = {k: str(v) for (k, v) in psql_order} actual_queries = {} for qid, qpat in psql_order: - with self._open_for_reading(qpat, "r") as f: + with self._open_for_reading(qpat) as f: query = f.read() actual_queries[qid] = [(QueryType.SELECT, query)] else: @@ -674,6 +676,7 @@ def execute( # Execute benchbase if specified. success = self._execute_benchbase(benchbase_config, results_dpath) # We can only create a state if we succeeded. + assert self.dbgym_cfg is not None success = observation_space.check_benchbase(self.dbgym_cfg, results_dpath) else: num_timed_out_queries, did_workload_time_out, query_metric_data = ( From 330314813f62bf5f5cea12e66dc358119b68e048 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:48:14 +0000 Subject: [PATCH 35/60] fixed tune/protox/env/space/holon_space.py --- tune/protox/env/space/holon_space.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tune/protox/env/space/holon_space.py b/tune/protox/env/space/holon_space.py index ee51a5de..cc66ab4c 100644 --- a/tune/protox/env/space/holon_space.py +++ b/tune/protox/env/space/holon_space.py @@ -39,7 +39,7 @@ def _latent_assert_check( carprod_neighbors: list[HolonAction], carprod_embeds: torch.Tensor, first_drift: int, - ): + ) -> None: zero = self.to_latent([carprod_neighbors[0]])[0] last = self.to_latent([carprod_neighbors[-1]])[0] first_d = self.to_latent([carprod_neighbors[first_drift]])[0] @@ -232,19 +232,19 @@ def neighborhood( neighbor_parameters: NeighborParameters = DEFAULT_NEIGHBOR_PARAMETERS, ) -> Tuple[list[HolonAction], ProtoAction, torch.Tensor]: env_acts = [] - emb_acts: List[torch.Tensor] = [] + emb_acts: list[torch.Tensor] = [] ndims = [] env_action = self.from_latent(raw_action) for proto in env_action: # Figure out the neighbors for each subspace. - envs_neighbors = [] - embed_neighbors = [] + envs_neighbors: list[Any] = [] + embed_neighbors: list[Any] = [] # TODO(wz2,PROTOX_DELTA): For pseudo-backwards compatibility, we meld the knob + query space together. # In this way, we don't actually generate knob x query cartesian product. # Rather, we directly fuse min(knob_neighbors, query_neighbors) together and then cross with indexes. 
- meld_groups = [ + meld_groups: list[list[Any]] = [ [self.get_knob_space(), self.get_query_space()], [self.get_index_space()], ] From 04aae88550b540a7a42456c5098ec78713474aca Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:49:42 +0000 Subject: [PATCH 36/60] fixed tune/protox/agent/off_policy_algorithm.py --- tune/protox/agent/base_class.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tune/protox/agent/base_class.py b/tune/protox/agent/base_class.py index 3f999335..9b3dac26 100644 --- a/tune/protox/agent/base_class.py +++ b/tune/protox/agent/base_class.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import NDArray +from misc.utils import TuningMode from tune.protox.agent.agent_env import AgentEnv from tune.protox.agent.noise import ActionNoise from tune.protox.env.logger import Logger @@ -75,7 +76,7 @@ def _setup_learn( return total_timesteps @abstractmethod - def learn(self, env: AgentEnv, total_timesteps: int) -> None: + def learn(self, env: AgentEnv, total_timesteps: int, tuning_mode: TuningMode) -> None: """ Return a trained model. From bb1043f9df355d36525979bb8cbef2b032f16578 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:50:24 +0000 Subject: [PATCH 37/60] fixed tune/protox/env/util/pg_conn.py --- tune/protox/env/util/pg_conn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/env/util/pg_conn.py b/tune/protox/env/util/pg_conn.py index 77616ae9..cc89d722 100644 --- a/tune/protox/env/util/pg_conn.py +++ b/tune/protox/env/util/pg_conn.py @@ -272,7 +272,7 @@ def _set_up_boot( mu_hyp_opt: float, mu_hyp_time: int, mu_hyp_stdev: float, - ): + ) -> None: """ Sets up Boot on the currently running Postgres instances. Uses instance vars of PostgresConn for configuration. From 8d7d109243945fb1209cbe48533e20a5051d7fe2 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:52:48 +0000 Subject: [PATCH 38/60] fixed tune/protox/env/logger.py --- tune/protox/env/logger.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tune/protox/env/logger.py b/tune/protox/env/logger.py index 07459a18..36750f10 100644 --- a/tune/protox/env/logger.py +++ b/tune/protox/env/logger.py @@ -9,7 +9,7 @@ import numpy as np from plumbum import local -from torch.utils.tensorboard import SummaryWriter +from torch.utils.tensorboard.writer import SummaryWriter from typing_extensions import ParamSpec from misc.utils import DBGymConfig @@ -25,7 +25,7 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T: ret = f(*args, **kwargs) # TODO(wz2): This is a hack to get a logger instance. 
- first_arg = args[0] # type: ignore[index] # Ignore the indexing type error + first_arg = args[0] # Ignore the indexing type error assert hasattr(first_arg, "logger"), print(first_arg, type(first_arg)) if first_arg.logger is None: @@ -82,7 +82,7 @@ def __init__( self.writer: Union[SummaryWriter, None] = None if self.trace: self.tensorboard_dpath.mkdir(parents=True, exist_ok=True) - self.writer = SummaryWriter(self.tensorboard_dpath) + self.writer = SummaryWriter(self.tensorboard_dpath) # type: ignore[no-untyped-call] self.iteration = 1 self.iteration_data: dict[str, Any] = {} @@ -145,14 +145,14 @@ def advance(self) -> None: for key, value in self.iteration_data.items(): if isinstance(value, str): # str is considered a np.ScalarType - self.writer.add_text(key, value, self.iteration) + self.writer.add_text(key, value, self.iteration) # type: ignore[no-untyped-call] else: - self.writer.add_scalar(key, value, self.iteration) + self.writer.add_scalar(key, value, self.iteration) # type: ignore[no-untyped-call] del self.iteration_data self.iteration_data = {} self.iteration += 1 - self.writer.flush() + self.writer.flush() # type: ignore[no-untyped-call] def record(self, key: str, value: Any) -> None: stack = inspect.stack(context=2) @@ -169,4 +169,4 @@ def flush(self) -> None: if self.trace: assert self.writer self.advance() - self.writer.flush() + self.writer.flush() # type: ignore[no-untyped-call] From cb2d0687a69fd38900a845993bce5696190c58f3 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 20:58:10 +0000 Subject: [PATCH 39/60] fixed tune/protox/tests/test_workload_utils.py --- tune/protox/tests/test_workload_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tune/protox/tests/test_workload_utils.py b/tune/protox/tests/test_workload_utils.py index b1e63cf7..3f433f37 100644 --- a/tune/protox/tests/test_workload_utils.py +++ b/tune/protox/tests/test_workload_utils.py @@ -2,7 +2,8 @@ import pglast -from tune.protox.env.util.workload_analysis import * +from tune.protox.env.types import QueryType, AttrTableListMap +from tune.protox.env.util.workload_analysis import extract_aliases, extract_sqltypes, extract_columns class WorkloadUtilsTests(unittest.TestCase): @@ -16,7 +17,7 @@ class WorkloadUtilsTests(unittest.TestCase): "nation", "region", ] - TPCH_ALL_ATTRIBUTES = { + TPCH_ALL_ATTRIBUTES = AttrTableListMap({ "r_regionkey": ["region"], "r_name": ["region"], "r_comment": ["region"], @@ -78,7 +79,7 @@ class WorkloadUtilsTests(unittest.TestCase): "l_shipinstruct": ["lineitem"], "l_shipmode": ["lineitem"], "l_comment": ["lineitem"], - } + }) TPCH_Q1 = """ select l_returnflag, @@ -104,10 +105,10 @@ class WorkloadUtilsTests(unittest.TestCase): """ @staticmethod - def pglast_parse(sql): + def pglast_parse(sql: str) -> pglast.ast.Node: return pglast.parse_sql(sql) - def test_extract_aliases(self): + def test_extract_aliases(self) -> None: sql = "select * from t1 as t1_alias; select * from t1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) @@ -116,21 +117,21 @@ def test_extract_aliases(self): self.assertTrue("t1" in aliases and len(aliases) == 1) self.assertEqual(set(aliases["t1"]), set(["t1", "t1_alias"])) - def test_extract_aliases_ignores_views_in_create_view(self): + def test_extract_aliases_ignores_views_in_create_view(self) -> None: sql = "create view view1 (view1_c1) as select c1 from t1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) # all tables have only one alias so 
we can do this simpler assertion code self.assertEqual(aliases, {"t1": ["t1"]}) - def test_extract_aliases_doesnt_ignore_views_that_are_used(self): + def test_extract_aliases_doesnt_ignore_views_that_are_used(self) -> None: sql = "create view view1 (view1_c1) as select c1 from t1; select * from view1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) # all tables have only one alias so we can do this simpler assertion code self.assertEqual(aliases, {"t1": ["t1"], "view1": ["view1"]}) - def test_extract_sqltypes(self): + def test_extract_sqltypes(self) -> None: sql = """ select * from t1; update t1 set t1.c1 = 0 where t1.c1 = 1; @@ -150,7 +151,7 @@ def test_extract_sqltypes(self): self.assertEqual(sqltypes[1][0], QueryType.INS_UPD_DEL) self.assertEqual(sqltypes[2][0], QueryType.CREATE_VIEW) - def test_extract_columns(self): + def test_extract_columns(self) -> None: sql = WorkloadUtilsTests.TPCH_Q1 tables = WorkloadUtilsTests.TPCH_TABLES all_attributes = WorkloadUtilsTests.TPCH_ALL_ATTRIBUTES @@ -194,7 +195,7 @@ def test_extract_columns(self): ), ) - def test_extract_columns_with_cte(self): + def test_extract_columns_with_cte(self) -> None: sql = """ with cte1 as ( select t1.c1 @@ -205,7 +206,7 @@ def test_extract_columns_with_cte(self): from cte1; """ tables = ["t1"] - all_attributes = {"c1": "t1", "c2": "t1"} + all_attributes = AttrTableListMap({"c1": ["t1"], "c2": ["t1"]}) stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) self.assertEqual(len(stmts), 1) From 47c762edfd0ada6237e2e41ff2c7ef06ccf77a4c Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:00:36 +0000 Subject: [PATCH 40/60] fixed tune/protox/env/util/workload_analysis.py --- tune/protox/env/util/workload_analysis.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tune/protox/env/util/workload_analysis.py b/tune/protox/env/util/workload_analysis.py index de4be450..9db06e5e 100644 --- a/tune/protox/env/util/workload_analysis.py +++ b/tune/protox/env/util/workload_analysis.py @@ -1,9 +1,8 @@ -from enum import Enum, unique from typing import Iterator, Optional, Tuple -import pglast # type: ignore +import pglast from pglast import stream -from pglast.visitors import Continue, Visitor # type: ignore +from pglast.visitors import Continue, Visitor from tune.protox.env.types import ( AttrTableListMap, From 2f68f7d02a93f711a97e103540017d34fece3206 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:02:57 +0000 Subject: [PATCH 41/60] fixed tune/protox/tests/test_primitive.py --- tune/protox/env/space/primitive/knob.py | 2 +- tune/protox/tests/test_primitive.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tune/protox/env/space/primitive/knob.py b/tune/protox/env/space/primitive/knob.py index f71e397f..7905bf70 100644 --- a/tune/protox/env/space/primitive/knob.py +++ b/tune/protox/env/space/primitive/knob.py @@ -44,7 +44,7 @@ class KnobMetadata(TypedDict, total=False): type: str min: float max: float - quantize: bool + quantize: int log_scale: int unit: int values: list[str] diff --git a/tune/protox/tests/test_primitive.py b/tune/protox/tests/test_primitive.py index f9c2bd29..d7590d80 100644 --- a/tune/protox/tests/test_primitive.py +++ b/tune/protox/tests/test_primitive.py @@ -7,7 +7,7 @@ class PrimitivesTests(unittest.TestCase): - def test_linear_knob(self): + def test_linear_knob(self) -> None: k = Knob( table_name=None, query_name="q", @@ -30,7 +30,7 @@ def test_linear_knob(self): 
self.assertEqual(k.project_scraped_setting(0.58), 0.5) self.assertEqual(round(k.project_scraped_setting(0.62), 2), 0.6) - def test_log_knob(self): + def test_log_knob(self) -> None: k = Knob( table_name=None, query_name="q", @@ -53,7 +53,7 @@ def test_log_knob(self): self.assertEqual(k.project_scraped_setting(24), 32.0) self.assertEqual(k.project_scraped_setting(1024), 1024.0) - def test_latent_knob(self): + def test_latent_knob(self) -> None: k = LatentKnob( table_name=None, query_name="q", @@ -85,7 +85,7 @@ def test_latent_knob(self): self.assertEqual(k.shift_offset(0.5, 1), 0.6) self.assertEqual(k.shift_offset(0.5, -2), 0.3) - def test_ia(self): + def test_ia(self) -> None: ia1 = IndexAction( idx_type="btree", tbl="tbl", @@ -95,7 +95,7 @@ def test_ia(self): raw_repr=None, bias=0.0, ) - IndexAction.index_counter = 0 + IndexAction.index_name_counter = 0 self.assertEqual( ia1.sql(add=True), "CREATE INDEX index0 ON tbl USING btree (a,b,c) INCLUDE (d,e)", From 09bdf394a70205c4bb73054264e8daebdc17719e Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:05:31 +0000 Subject: [PATCH 42/60] fixed benchmark/tpch/cli.py --- benchmark/cli.py | 2 +- benchmark/tpch/cli.py | 20 ++++++++++---------- dbms/cli.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/benchmark/cli.py b/benchmark/cli.py index 2edf5d7f..cd58d55e 100644 --- a/benchmark/cli.py +++ b/benchmark/cli.py @@ -6,7 +6,7 @@ @click.group(name="benchmark") @click.pass_obj -def benchmark_group(dbgym_cfg: DBGymConfig): +def benchmark_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("benchmark") diff --git a/benchmark/tpch/cli.py b/benchmark/tpch/cli.py index 975fd769..270cb629 100644 --- a/benchmark/tpch/cli.py +++ b/benchmark/tpch/cli.py @@ -1,6 +1,5 @@ import logging -import os -import shutil +from pathlib import Path import click @@ -10,7 +9,6 @@ link_result, workload_name_fn, ) -from util.pg import * from util.shell import subprocess_run benchmark_tpch_logger = logging.getLogger("benchmark/tpch") @@ -19,7 +17,7 @@ @click.group(name="tpch") @click.pass_obj -def tpch_group(dbgym_cfg: DBGymConfig): +def tpch_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("tpch") @@ -28,7 +26,7 @@ def tpch_group(dbgym_cfg: DBGymConfig): @click.pass_obj # The reason generate data is separate from create dbdata is because generate-data is generic # to all DBMSs while create dbdata is specific to a single DBMS. 
-def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float): +def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None: _clone(dbgym_cfg) _generate_data(dbgym_cfg, scale_factor) @@ -59,7 +57,7 @@ def tpch_workload( seed_end: int, query_subset: str, scale_factor: float, -): +) -> None: assert ( seed_start <= seed_end ), f"seed_start ({seed_start}) must be <= seed_end ({seed_end})" @@ -72,7 +70,7 @@ def _get_queries_dname(seed: int, scale_factor: float) -> str: return f"queries_{seed}_sf{get_scale_factor_string(scale_factor)}" -def _clone(dbgym_cfg: DBGymConfig): +def _clone(dbgym_cfg: DBGymConfig) -> None: expected_symlink_dpath = ( dbgym_cfg.cur_symlinks_build_path(mkdir=True) / "tpch-kit.link" ) @@ -102,7 +100,7 @@ def _get_tpch_kit_dpath(dbgym_cfg: DBGymConfig) -> Path: def _generate_queries( dbgym_cfg: DBGymConfig, seed_start: int, seed_end: int, scale_factor: float -): +) -> None: tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg) data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True) benchmark_tpch_logger.info( @@ -132,7 +130,7 @@ def _generate_queries( ) -def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float): +def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None: tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg) data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True) expected_tables_symlink_dpath = ( @@ -162,7 +160,7 @@ def _generate_workload( seed_end: int, query_subset: str, scale_factor: float, -): +) -> None: symlink_data_dpath = dbgym_cfg.cur_symlinks_data_path(mkdir=True) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) expected_workload_symlink_dpath = symlink_data_dpath / (workload_name + ".link") @@ -177,6 +175,8 @@ def _generate_workload( queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 0] elif query_subset == "odd": queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 1] + else: + assert False with open(real_dpath / "order.txt", "w") as f: for seed in range(seed_start, seed_end + 1): diff --git a/dbms/cli.py b/dbms/cli.py index b71bed18..e53c3113 100644 --- a/dbms/cli.py +++ b/dbms/cli.py @@ -5,7 +5,7 @@ @click.group(name="dbms") @click.pass_obj -def dbms_group(dbgym_cfg): +def dbms_group(dbgym_cfg) -> None: dbgym_cfg.append_group("dbms") From c80f81430fb1b5f7ffe2b6639ca88021cf1a24a6 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:09:57 +0000 Subject: [PATCH 43/60] fixed dbms/postgres/cli.py --- dbms/cli.py | 3 ++- dbms/load_info_base_class.py | 9 ++++++--- dbms/postgres/cli.py | 27 ++++++++++++++++----------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/dbms/cli.py b/dbms/cli.py index e53c3113..990f096c 100644 --- a/dbms/cli.py +++ b/dbms/cli.py @@ -1,11 +1,12 @@ import click from dbms.postgres.cli import postgres_group +from misc.utils import DBGymConfig @click.group(name="dbms") @click.pass_obj -def dbms_group(dbgym_cfg) -> None: +def dbms_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("dbms") diff --git a/dbms/load_info_base_class.py b/dbms/load_info_base_class.py index 99b1032e..c09ac6b3 100644 --- a/dbms/load_info_base_class.py +++ b/dbms/load_info_base_class.py @@ -1,3 +1,6 @@ +from pathlib import Path + + class LoadInfoBaseClass: """ A base class for providing info for DBMSs to load the data of a benchmark @@ -5,12 +8,12 @@ class LoadInfoBaseClass: copy the comments or type annotations or else they might become out of sync. 
""" - def get_schema_fpath(self) -> str: + def get_schema_fpath(self) -> Path: raise NotImplemented - def get_tables_and_fpaths(self) -> list[tuple[str, str]]: + def get_tables_and_fpaths(self) -> list[tuple[str, Path]]: raise NotImplemented # If the subclassing benchmark does not have constraints, you can return None here - def get_constraints_fpath(self) -> str | None: + def get_constraints_fpath(self) -> Path | None: raise NotImplemented diff --git a/dbms/postgres/cli.py b/dbms/postgres/cli.py index 6dd6e40e..ff729b73 100644 --- a/dbms/postgres/cli.py +++ b/dbms/postgres/cli.py @@ -10,6 +10,7 @@ import shutil import subprocess from pathlib import Path +from typing import Optional import click from sqlalchemy import Connection @@ -47,7 +48,7 @@ @click.group(name="postgres") @click.pass_obj -def postgres_group(dbgym_cfg: DBGymConfig): +def postgres_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("postgres") @@ -61,7 +62,7 @@ def postgres_group(dbgym_cfg: DBGymConfig): is_flag=True, help="Include this flag to rebuild Postgres even if it already exists.", ) -def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool): +def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool) -> None: _build_repo(dbgym_cfg, rebuild) @@ -94,10 +95,10 @@ def postgres_dbdata( dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float, - pgbin_path: Path, + pgbin_path: Optional[Path], intended_dbdata_hardware: str, - dbdata_parent_dpath: Path, -): + dbdata_parent_dpath: Optional[Path], +) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) @@ -138,7 +139,7 @@ def _get_repo_symlink_path(dbgym_cfg: DBGymConfig) -> Path: return dbgym_cfg.cur_symlinks_build_path("repo.link") -def _build_repo(dbgym_cfg: DBGymConfig, rebuild): +def _build_repo(dbgym_cfg: DBGymConfig, rebuild: bool) -> None: expected_repo_symlink_dpath = _get_repo_symlink_path(dbgym_cfg) if not rebuild and expected_repo_symlink_dpath.exists(): dbms_postgres_logger.info( @@ -209,7 +210,7 @@ def _create_dbdata( dbms_postgres_logger.info(f"Created dbdata in {dbdata_tgz_symlink_path}") -def _generic_dbdata_setup(dbgym_cfg: DBGymConfig): +def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None: # get necessary vars pgbin_real_dpath = _get_pgbin_symlink_path(dbgym_cfg).resolve() assert pgbin_real_dpath.exists() @@ -247,7 +248,7 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig): def _load_benchmark_into_dbdata( dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float -): +) -> None: with create_conn(use_psycopg=False) as conn: if benchmark_name == "tpch": load_info = TpchLoadInfo(dbgym_cfg, scale_factor) @@ -261,7 +262,7 @@ def _load_benchmark_into_dbdata( def _load_into_dbdata( dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass -): +) -> None: sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath()) # truncate all tables first before even loading a single one @@ -270,13 +271,17 @@ def _load_into_dbdata( # then, load the tables for table, table_fpath in load_info.get_tables_and_fpaths(): with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv: - with conn.connection.dbapi_connection.cursor() as cur: + assert conn.connection.dbapi_connection is not None + cur = conn.connection.dbapi_connection.cursor() + try: with cur.copy(f"COPY {table} FROM STDIN CSV DELIMITER '|'") as copy: while data := table_csv.read(8192): copy.write(data) + finally: + cur.close() 
constraints_fpath = load_info.get_constraints_fpath() - if constraints_fpath != None: + if constraints_fpath is not None: sql_file_execute(dbgym_cfg, conn, constraints_fpath) From e39c151313aaf231a92e1576b459d66258cbe6d4 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:31:48 +0000 Subject: [PATCH 44/60] fixed manage/tests/test_clean.py --- manage/cli.py | 14 +- manage/tests/test_clean.py | 324 ++++++++++++++++++------------------- 2 files changed, 171 insertions(+), 167 deletions(-) diff --git a/manage/cli.py b/manage/cli.py index 3f3cba2e..d4e6e22e 100644 --- a/manage/cli.py +++ b/manage/cli.py @@ -8,12 +8,22 @@ import click import yaml -from misc.utils import DBGymConfig, is_child_path, parent_dpath_of_path +from misc.utils import DBGymConfig, get_runs_path_from_workspace_path, get_symlinks_path_from_workspace_path, is_child_path, parent_dpath_of_path task_logger = logging.getLogger("task") task_logger.setLevel(logging.INFO) +# This is used in test_clean.py. It's defined here to avoid a circular import. +class MockDBGymConfig: + def __init__(self, scratchspace_path: Path): + self.dbgym_workspace_path = scratchspace_path + self.dbgym_symlinks_path = get_symlinks_path_from_workspace_path( + scratchspace_path + ) + self.dbgym_runs_path = get_runs_path_from_workspace_path(scratchspace_path) + + @click.group(name="manage") def manage_group(): pass @@ -136,7 +146,7 @@ def _count_files_in_workspace(dbgym_cfg: DBGymConfig) -> int: return total_count -def clean_workspace(dbgym_cfg: DBGymConfig, mode: str = "safe", verbose=False) -> None: +def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose=False) -> None: """ Clean all [workspace]/task_runs/run_*/ directories that are not referenced by any "active symlinks". If mode is "aggressive", "active symlinks" means *only* the symlinks directly in [workspace]/symlinks/. diff --git a/manage/tests/test_clean.py b/manage/tests/test_clean.py index 2ba24249..27fcc305 100644 --- a/manage/tests/test_clean.py +++ b/manage/tests/test_clean.py @@ -2,13 +2,12 @@ import logging import os import shutil +from typing import Any, NewType, cast import unittest from pathlib import Path -from manage.cli import clean_workspace +from manage.cli import MockDBGymConfig, clean_workspace from misc.utils import ( - get_runs_path_from_workspace_path, - get_symlinks_path_from_workspace_path, path_exists_dont_follow_symlinks, ) @@ -18,13 +17,7 @@ logging.basicConfig(level=logging.INFO) -class MockDBGymConfig: - def __init__(self, scratchspace_path: Path): - self.dbgym_workspace_path = scratchspace_path - self.dbgym_symlinks_path = get_symlinks_path_from_workspace_path( - scratchspace_path - ) - self.dbgym_runs_path = get_runs_path_from_workspace_path(scratchspace_path) +FilesystemStructure = NewType("FilesystemStructure", dict[str, Any]) class CleanTests(unittest.TestCase): @@ -32,18 +25,19 @@ class CleanTests(unittest.TestCase): I deemed "clean" important enough to write extensive unit tests for because a bug could lead to losing important files. 
""" + scratchspace_path: Path = Path() @staticmethod - def create_structure(root_path: Path, structure: dict) -> None: + def create_structure(root_path: Path, structure: FilesystemStructure) -> None: def create_structure_internal( - root_path: Path, cur_path: Path, structure: dict + root_path: Path, cur_path: Path, structure: FilesystemStructure ) -> None: for path, content in structure.items(): full_path: Path = cur_path / path if isinstance(content, dict): # Directory full_path.mkdir(parents=True, exist_ok=True) - create_structure_internal(root_path, full_path, content) + create_structure_internal(root_path, full_path, FilesystemStructure(cast(dict[str, Any], content))) elif isinstance(content, tuple) and content[0] == "file": assert len(content) == 1 full_path.touch() @@ -58,9 +52,9 @@ def create_structure_internal( create_structure_internal(root_path, root_path, structure) @staticmethod - def verify_structure(root_path: Path, structure: dict) -> bool: + def verify_structure(root_path: Path, structure: FilesystemStructure) -> bool: def verify_structure_internal( - root_path: Path, cur_path: Path, structure: dict + root_path: Path, cur_path: Path, structure: FilesystemStructure ) -> bool: # Check for the presence of each item specified in the structure for name, item in structure.items(): @@ -72,7 +66,7 @@ def verify_structure_internal( if not new_cur_path.is_dir(): logging.debug(f"expected {new_cur_path} to be a directory") return False - if not verify_structure_internal(root_path, new_cur_path, item): + if not verify_structure_internal(root_path, new_cur_path, FilesystemStructure(cast(dict[str, Any], item))): return False elif isinstance(item, tuple) and item[0] == "file": if not new_cur_path.is_file(): @@ -111,36 +105,36 @@ def verify_structure_internal( @staticmethod def make_workspace_structure( - symlinks_structure: dict, task_runs_structure: dict - ) -> dict: + symlinks_structure: FilesystemStructure, task_runs_structure: FilesystemStructure + ) -> FilesystemStructure: """ This function exists so that it's easier to refactor the tests in case we ever change how the workspace is organized. 
""" - return { + return FilesystemStructure({ "symlinks": symlinks_structure, "task_runs": task_runs_structure, - } + }) @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.scratchspace_path = Path.cwd() / "manage/tests/test_clean_scratchspace/" - def setUp(self): + def setUp(self) -> None: if self.scratchspace_path.exists(): shutil.rmtree(self.scratchspace_path) - def tearDown(self): + def tearDown(self) -> None: if self.scratchspace_path.exists(): shutil.rmtree(self.scratchspace_path) - def test_structure_helpers(self): - structure = { + def test_structure_helpers(self) -> None: + structure = FilesystemStructure({ "dir1": {"file1.txt": ("file",), "dir2": {"file2.txt": ("file",)}}, "dir3": {"nested_link_to_dir1": ("symlink", "dir1")}, "link_to_dir1": ("symlink", "dir1"), "link_to_file2": ("symlink", "dir1/dir2/file2.txt"), - } + }) CleanTests.create_structure(self.scratchspace_path, structure) self.assertTrue(CleanTests.verify_structure(self.scratchspace_path, structure)) @@ -214,44 +208,44 @@ def test_structure_helpers(self): CleanTests.verify_structure(self.scratchspace_path, wrong_link_structure) ) - def test_nonexistent_workspace(self): + def test_nonexistent_workspace(self) -> None: clean_workspace(MockDBGymConfig(self.scratchspace_path)) - def test_no_symlinks_dir_and_no_task_runs_dir(self): - starting_structure = {} - ending_structure = {} + def test_no_symlinks_dir_and_no_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure({}) + ending_structure = FilesystemStructure({}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_symlinks_dir_and_yes_task_runs_dir(self): - starting_structure = {"task_runs": {"file1.txt": ("file",)}} - ending_structure = {"task_runs": {}} + def test_no_symlinks_dir_and_yes_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure({"task_runs": {"file1.txt": ("file",)}}) + ending_structure = FilesystemStructure({"task_runs": {}}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_yes_symlinks_dir_and_no_task_runs_dir(self): - starting_structure = {"symlinks": {}} - ending_structure = {"symlinks": {}} + def test_yes_symlinks_dir_and_no_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure({"symlinks": {}}) + ending_structure = FilesystemStructure({"symlinks": {}}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {} + def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {} - ending_task_runs_structure = {} + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( 
ending_symlinks_structure, ending_task_runs_structure ) @@ -262,14 +256,14 @@ def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_links_in_symlinks(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {"run_0": {}} + def test_no_links_in_symlinks(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({"run_0": {}}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {} - ending_task_runs_structure = {} + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -280,14 +274,14 @@ def test_no_links_in_symlinks(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_file_directly_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/file1.txt")} - starting_task_runs_structure = {"file1.txt": ("file",), "file2.txt": ("file",)} + def test_link_to_file_directly_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/file1.txt")}) + starting_task_runs_structure = FilesystemStructure({"file1.txt": ("file",), "file2.txt": ("file",)}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/file1.txt")} - ending_task_runs_structure = {"file1.txt": ("file",)} + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/file1.txt")}) + ending_task_runs_structure = FilesystemStructure({"file1.txt": ("file",)}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -298,17 +292,17 @@ def test_link_to_file_directly_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_dir_directly_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_link_to_dir_directly_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"file1.txt": ("file",)}, "dir2": {"file2.txt": ("file",)}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = {"dir1": {"file1.txt": ("file",)}} + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({"dir1": {"file1.txt": ("file",)}}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -319,21 +313,21 @@ def test_link_to_dir_directly_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_file_in_dir_in_task_runs(self): - starting_symlinks_structure = { + def test_link_to_file_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure({ 
"symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { + }) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"file1.txt": ("file",)}, "dir2": {"file2.txt": ("file",)}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = { + ending_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = {"dir1": {"file1.txt": ("file",)}} + }) + ending_task_runs_structure = FilesystemStructure({"dir1": {"file1.txt": ("file",)}}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -344,19 +338,19 @@ def test_link_to_file_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_dir_in_dir_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1/dir2")} - starting_task_runs_structure = { + def test_link_to_dir_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1/dir2")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, "dir3": {"file3.txt": ("file",)}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1/dir2")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1/dir2")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -367,12 +361,12 @@ def test_link_to_dir_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_link_crashes(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/symlink2")} - starting_task_runs_structure = { + def test_link_to_link_crashes(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/symlink2")}) + starting_task_runs_structure = FilesystemStructure({ "symlink2": ("symlink", "task_runs/file1.txt"), "file1.txt": ("file",), - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -381,21 +375,21 @@ def test_link_to_link_crashes(self): with self.assertRaises(AssertionError): clean_workspace(MockDBGymConfig(self.scratchspace_path)) - def test_safe_mode_link_to_dir_with_link(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_safe_mode_link_to_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, "file1.txt": ("file",), "file2.txt": ("file",), - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + 
ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, "file1.txt": ("file",), - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -406,31 +400,31 @@ def test_safe_mode_link_to_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_safe_mode_link_to_file_in_dir_with_link(self): - starting_symlinks_structure = { + def test_safe_mode_link_to_file_in_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { + }) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/file2.txt"), }, "file2.txt": ("file",), "file3.txt": ("file",), - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = { + ending_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = { + }) + ending_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/file2.txt"), }, "file2.txt": ("file",), - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -441,25 +435,25 @@ def test_safe_mode_link_to_file_in_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, "dir2": { "file2.txt": ("file",), }, "file3.txt": ("file",), - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, "dir2": { "file2.txt": ("file",), }, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -470,20 +464,20 @@ def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_aggressive_mode_link_to_dir_with_link(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_aggressive_mode_link_to_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, "file1.txt": ("file",), "file2.txt": ("file",), - } + }) 
starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", None)}, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -494,14 +488,14 @@ def test_aggressive_mode_link_to_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_link_to_file_gives_error(self): - starting_symlinks_structure = { + def test_link_to_link_to_file_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/symlink2") - } - starting_task_runs_structure = { + }) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "task_runs/file2.txt")}, "file2.txt": ("file",), - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -512,13 +506,13 @@ def test_link_to_link_to_file_gives_error(self): with self.assertRaises(AssertionError): clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") - def test_multi_link_loop_gives_error(self): - starting_symlinks_structure = { + def test_multi_link_loop_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/symlink2") - } - starting_task_runs_structure = { + }) + starting_task_runs_structure = FilesystemStructure({ "dir1": {"symlink2": ("symlink", "symlinks/symlink1")}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -529,9 +523,9 @@ def test_multi_link_loop_gives_error(self): with self.assertRaises(RuntimeError): clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") - def test_link_self_loop_gives_error(self): - starting_symlinks_structure = {"symlink1": ("symlink", "symlinks/symlink1")} - starting_task_runs_structure = dict() + def test_link_self_loop_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "symlinks/symlink1")}) + starting_task_runs_structure = FilesystemStructure({}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -544,9 +538,9 @@ def test_link_self_loop_gives_error(self): def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs( self, - ): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + ) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir2/file2.txt"), @@ -555,12 +549,12 @@ def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs "file2.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + 
ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir2/file2.txt"), @@ -569,7 +563,7 @@ def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs "file2.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -582,24 +576,24 @@ def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_itself( self, - ): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + ) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -610,24 +604,24 @@ def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_ CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "task_runs/dir1/file1.txt"), }, - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -638,22 +632,22 @@ def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_broken_symlink_has_no_effect(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { + def test_broken_symlink_has_no_effect(self) -> None: + starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": 
("symlink", "task_runs/dir1/non_existent_file.txt"), }, "dir2": {"file2.txt": ("file",)}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { + ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) + ending_task_runs_structure = FilesystemStructure({ "dir1": {"file1.txt": ("file",), "symlink2": ("symlink", None)} - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -667,35 +661,35 @@ def test_broken_symlink_has_no_effect(self): # The idea behind this test is that we shouldn't be following links outside of task_runs, even on safe mode def test_link_to_folder_outside_runs_that_contains_link_to_other_run_doesnt_save_other_run( self, - ): - starting_symlinks_structure = { + ) -> None: + starting_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { + }) + starting_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "external/dir3/file3.txt"), }, "dir2": {"file2.txt": ("file",)}, - } + }) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - starting_structure["external"] = { + starting_structure["external"] = FilesystemStructure({ "dir3": { "file3.txt": ("file",), "symlink3": ("symlink", "task_runs/dir2/file2.txt"), } - } - ending_symlinks_structure = { + }) + ending_symlinks_structure = FilesystemStructure({ "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = { + }) + ending_task_runs_structure = FilesystemStructure({ "dir1": { "file1.txt": ("file",), "symlink2": ("symlink", "external/dir3/file3.txt"), } - } + }) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -709,19 +703,19 @@ def test_link_to_folder_outside_runs_that_contains_link_to_other_run_doesnt_save CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_outside_task_runs_doesnt_get_deleted(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {"dir1": {}} + def test_outside_task_runs_doesnt_get_deleted(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({"dir1": {}}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - starting_structure["external"] = {"file1.txt": ("file",)} - ending_symlinks_structure = {} - ending_task_runs_structure = {} + starting_structure["external"] = FilesystemStructure({"file1.txt": ("file",)}) + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) - ending_structure["external"] = {"file1.txt": ("file",)} + ending_structure["external"] = FilesystemStructure({"file1.txt": ("file",)}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") From b93d3f47c36516ec9e1f23bd2134ea93a57e90e0 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:33:05 +0000 Subject: [PATCH 45/60] fixed 
benchmark/tpch/load_info.py --- benchmark/tpch/load_info.py | 8 +++++--- dbms/load_info_base_class.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmark/tpch/load_info.py b/benchmark/tpch/load_info.py index 2c84ac2b..e678c6be 100644 --- a/benchmark/tpch/load_info.py +++ b/benchmark/tpch/load_info.py @@ -1,3 +1,5 @@ +from pathlib import Path +from typing import Optional from dbms.load_info_base_class import LoadInfoBaseClass from misc.utils import DBGymConfig, get_scale_factor_string @@ -55,11 +57,11 @@ def __init__(self, dbgym_cfg: DBGymConfig, scale_factor: float): table_fpath = tables_dpath / f"{table}.tbl" self._tables_and_fpaths.append((table, table_fpath)) - def get_schema_fpath(self): + def get_schema_fpath(self) -> Path: return self._schema_fpath - def get_tables_and_fpaths(self): + def get_tables_and_fpaths(self) -> list[tuple[str, Path]]: return self._tables_and_fpaths - def get_constraints_fpath(self): + def get_constraints_fpath(self) -> Optional[Path]: return self._constraints_fpath diff --git a/dbms/load_info_base_class.py b/dbms/load_info_base_class.py index c09ac6b3..40df2590 100644 --- a/dbms/load_info_base_class.py +++ b/dbms/load_info_base_class.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional class LoadInfoBaseClass: @@ -15,5 +16,5 @@ def get_tables_and_fpaths(self) -> list[tuple[str, Path]]: raise NotImplemented # If the subclassing benchmark does not have constraints, you can return None here - def get_constraints_fpath(self) -> Path | None: + def get_constraints_fpath(self) -> Optional[Path]: raise NotImplemented From fcf68942a89186add3faa28ca46bb359e1f11bf4 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:36:14 +0000 Subject: [PATCH 46/60] fixed manage/cli.py --- manage/cli.py | 80 ++++++--------------------------------------------- 1 file changed, 8 insertions(+), 72 deletions(-) diff --git a/manage/cli.py b/manage/cli.py index d4e6e22e..e0fd0aa9 100644 --- a/manage/cli.py +++ b/manage/cli.py @@ -25,71 +25,10 @@ def __init__(self, scratchspace_path: Path): @click.group(name="manage") -def manage_group(): +def manage_group() -> None: pass -@click.command(name="show") -@click.argument("keys", nargs=-1) -@click.pass_obj -def manage_show(dbgym_cfg, keys): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml - - # Traverse the YAML. - for key in keys: - config_yaml = config_yaml[key] - - # Pretty-print the requested YAML value. - output_str = None - if type(config_yaml) != dict: - output_str = config_yaml - else: - output_str = yaml.dump(config_yaml, default_flow_style=False) - if len(keys) > 0: - output_str = " " + output_str.replace("\n", "\n ") - output_str = output_str.rstrip() - print(output_str) - - task_logger.info(f"Read: {Path(config_path)}") - - -@click.command(name="write") -@click.argument("keys", nargs=-1) -@click.argument("value_type") -@click.argument("value") -@click.pass_obj -def manage_write(dbgym_cfg, keys, value_type, value): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml - - # Traverse the YAML. - root_yaml = config_yaml - for key in keys[:-1]: - config_yaml = config_yaml[key] - - # Modify the requested YAML value and write the YAML file. 
- assert type(config_yaml[keys[-1]]) != dict - config_yaml[keys[-1]] = getattr(__builtins__, value_type)(value) - new_yaml = yaml.dump(root_yaml, default_flow_style=False).rstrip() - Path(config_path).write_text(new_yaml) - - task_logger.info(f"Updated: {Path(config_path)}") - - -@click.command(name="standardize") -@click.pass_obj -def manage_standardize(dbgym_cfg): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml - - # Write the YAML file. - new_yaml = yaml.dump(config_yaml, default_flow_style=False).rstrip() - Path(config_path).write_text(new_yaml) - - task_logger.info(f"Updated: {Path(config_path)}") - - @click.command("clean") @click.pass_obj @click.option( @@ -98,13 +37,13 @@ def manage_standardize(dbgym_cfg): default="safe", help='The mode to clean the workspace (default="safe"). "aggressive" means "only keep run_*/ folders referenced by a file in symlinks/". "safe" means "in addition to that, recursively keep any run_*/ folders referenced by any symlinks in run_*/ folders we are keeping."', ) -def manage_clean(dbgym_cfg: DBGymConfig, mode: str): +def manage_clean(dbgym_cfg: DBGymConfig, mode: str) -> None: clean_workspace(dbgym_cfg, mode=mode, verbose=True) @click.command("count") @click.pass_obj -def manage_count(dbgym_cfg: DBGymConfig): +def manage_count(dbgym_cfg: DBGymConfig) -> None: num_files = _count_files_in_workspace(dbgym_cfg) print( f"The workspace ({dbgym_cfg.dbgym_workspace_path}) has {num_files} total files/dirs/symlinks." @@ -112,7 +51,7 @@ def manage_count(dbgym_cfg: DBGymConfig): def add_symlinks_in_dpath( - symlinks_stack: List[Path], root_dpath: Path, processed_symlinks: Set[Path] + symlinks_stack: list[Path], root_dpath: Path, processed_symlinks: set[Path] ) -> None: """ Will modify symlinks_stack and processed_symlinks. @@ -127,7 +66,7 @@ def add_symlinks_in_dpath( processed_symlinks.add(file_path) -def _count_files_in_workspace(dbgym_cfg: DBGymConfig) -> int: +def _count_files_in_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig) -> int: """ Counts the number of files (regular file or dir or symlink) in the workspace. """ @@ -146,7 +85,7 @@ def _count_files_in_workspace(dbgym_cfg: DBGymConfig) -> int: return total_count -def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose=False) -> None: +def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose: bool=False) -> None: """ Clean all [workspace]/task_runs/run_*/ directories that are not referenced by any "active symlinks". If mode is "aggressive", "active symlinks" means *only* the symlinks directly in [workspace]/symlinks/. @@ -154,9 +93,9 @@ def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe" any symlinks referenced in task_runs/run_*/ directories we have already decided to keep. """ # This stack holds the symlinks that are left to be processed - symlink_fpaths_to_process = [] + symlink_fpaths_to_process: list[Path] = [] # This set holds the symlinks that have already been processed to avoid infinite loops - processed_symlinks = set() + processed_symlinks: set[Path] = set() # 1. 
Initialize paths to process if dbgym_cfg.dbgym_symlinks_path.exists(): @@ -247,8 +186,5 @@ def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe" ) -manage_group.add_command(manage_show) -manage_group.add_command(manage_write) -manage_group.add_command(manage_standardize) manage_group.add_command(manage_clean) manage_group.add_command(manage_count) From 4273cd8277ca508cbb99b9831f67ec0e51c77185 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:39:33 +0000 Subject: [PATCH 47/60] replaced List/Dict/Set with list/dict/set --- manage/cli.py | 2 -- misc/utils.py | 2 +- tune/protox/agent/agent_env.py | 10 +++++----- tune/protox/agent/buffers.py | 4 ++-- tune/protox/agent/build_trial.py | 10 +++++----- tune/protox/agent/off_policy_algorithm.py | 8 ++++---- tune/protox/agent/policies.py | 6 +++--- tune/protox/agent/replay.py | 2 +- tune/protox/agent/torch_layers.py | 4 ++-- tune/protox/agent/wolp/policies.py | 4 ++-- tune/protox/agent/wolp/wolp.py | 8 ++++---- tune/protox/embedding/loss.py | 6 +++--- tune/protox/embedding/train_all.py | 6 +++--- tune/protox/embedding/trainer.py | 10 +++++----- tune/protox/embedding/vae.py | 20 +++++++++---------- tune/protox/env/lsc/lsc_wrapper.py | 4 ++-- tune/protox/env/mqo/mqo_wrapper.py | 4 ++-- tune/protox/env/pg_env.py | 12 +++++------ tune/protox/env/space/holon_space.py | 14 ++++++------- .../space/latent_space/latent_index_space.py | 4 ++-- .../space/latent_space/latent_knob_space.py | 4 ++-- .../space/latent_space/latent_query_space.py | 4 ++-- tune/protox/env/space/primitive/knob.py | 4 ++-- .../env/space/primitive_space/index_policy.py | 4 ++-- .../env/space/primitive_space/index_space.py | 2 +- tune/protox/env/space/state/structure.py | 2 +- tune/protox/env/space/utils.py | 2 +- .../env/target_reset/target_reset_wrapper.py | 4 ++-- tune/protox/env/types.py | 16 +++++++-------- tune/protox/env/util/execute.py | 4 ++-- tune/protox/env/util/pg_conn.py | 2 +- tune/protox/env/util/reward.py | 4 ++-- tune/protox/env/util/workload_analysis.py | 4 ++-- tune/protox/env/workload.py | 8 ++++---- util/pg.py | 2 +- 35 files changed, 102 insertions(+), 104 deletions(-) diff --git a/manage/cli.py b/manage/cli.py index e0fd0aa9..624b248f 100644 --- a/manage/cli.py +++ b/manage/cli.py @@ -3,10 +3,8 @@ import shutil from itertools import chain from pathlib import Path -from typing import List, Set import click -import yaml from misc.utils import DBGymConfig, get_runs_path_from_workspace_path, get_symlinks_path_from_workspace_path, is_child_path, parent_dpath_of_path diff --git a/misc/utils.py b/misc/utils.py index d68cc233..c434e3fd 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -445,7 +445,7 @@ def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode: str="r") -> IO def extract_from_task_run_fordpath( dbgym_cfg: DBGymConfig, task_run_fordpath: Path -) -> Tuple[Path, str, Path, str]: +) -> tuple[Path, str, Path, str]: """ The task_runs/ folder is organized like task_runs/run_*/[codebase]/[org]/any/path/you/want. 
This function extracts the [codebase] and [org] components diff --git a/tune/protox/agent/agent_env.py b/tune/protox/agent/agent_env.py index b5af657b..4a69c2ef 100644 --- a/tune/protox/agent/agent_env.py +++ b/tune/protox/agent/agent_env.py @@ -1,6 +1,6 @@ import copy import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import gymnasium as gym import numpy as np @@ -12,7 +12,7 @@ def __init__(self, env: gym.Env[Any, Any]): super().__init__(env) self.class_attributes = dict(inspect.getmembers(self.__class__)) - def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: observations, info = self.env.reset(**kwargs) self._check_val(event="reset", observations=observations) self._observations = observations @@ -20,7 +20,7 @@ def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: def step( self, actions: NDArray[np.float32] - ) -> Tuple[Any, float, bool, bool, dict[str, Any]]: + ) -> tuple[Any, float, bool, bool, dict[str, Any]]: self._actions = actions observations, rewards, term, trunc, infos = self.env.step(actions) @@ -50,7 +50,7 @@ def __getattr__(self, name: str) -> Any: return self.getattr_recursive(name) - def _get_all_attributes(self) -> Dict[str, Any]: + def _get_all_attributes(self) -> dict[str, Any]: """Get all (inherited) instance and class attributes :return: all_attributes @@ -97,7 +97,7 @@ def getattr_depth_check(self, name: str, already_found: bool) -> str: def check_array_value( self, name: str, value: NDArray[np.float32] - ) -> List[Tuple[str, str]]: + ) -> list[tuple[str, str]]: """ Check for inf and NaN for a single numpy array. diff --git a/tune/protox/agent/buffers.py b/tune/protox/agent/buffers.py index d2b7e351..d4de74d4 100644 --- a/tune/protox/agent/buffers.py +++ b/tune/protox/agent/buffers.py @@ -12,7 +12,7 @@ class ReplayBufferSamples(NamedTuple): next_observations: th.Tensor dones: th.Tensor rewards: th.Tensor - infos: List[dict[str, Any]] + infos: list[dict[str, Any]] class ReplayBuffer: @@ -68,7 +68,7 @@ def add( action: NDArray[np.float32], reward: float, done: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.action_dim)) diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index dde38718..917a6ac2 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -61,7 +61,7 @@ def _parse_activation_fn(act_type: str) -> type[nn.Module]: raise ValueError(f"Unsupported activation type {act_type}") -def _get_signal(signal_folder: Union[str, Path]) -> Tuple[int, str]: +def _get_signal(signal_folder: Union[str, Path]) -> tuple[int, str]: MIN_PORT = 5434 MAX_PORT = 5500 @@ -142,7 +142,7 @@ def _build_utilities( tuning_mode: TuningMode, pgport: int, hpo_params: dict[str, Any], -) -> Tuple[Logger, RewardUtility, PostgresConn, Workload]: +) -> tuple[Logger, RewardUtility, PostgresConn, Workload]: logger = Logger( dbgym_cfg, hpo_params["trace"], @@ -201,7 +201,7 @@ def _build_actions( hpo_params: dict[str, Any], workload: Workload, logger: Logger, -) -> Tuple[HolonSpace, LSC]: +) -> tuple[HolonSpace, LSC]: sysknobs = LatentKnobSpace( logger=logger, tables=hpo_params["benchmark_config"]["tables"], @@ -334,7 +334,7 @@ def _build_env( workload: Workload, reward_utility: RewardUtility, logger: Logger, -) -> Tuple[TargetResetWrapper, AgentEnv]: +) -> tuple[TargetResetWrapper, AgentEnv]: env = 
gym.make( "Postgres-v0", @@ -538,7 +538,7 @@ def build_trial( seed: int, hpo_params: dict[str, Any], ray_trial_id: Optional[str] = None, -) -> Tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]: +) -> tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]: # The massive trial builder. port, signal = _get_signal(hpo_params["pgconn_info"]["pgbin_path"]) diff --git a/tune/protox/agent/off_policy_algorithm.py b/tune/protox/agent/off_policy_algorithm.py index fd33004e..36567b29 100644 --- a/tune/protox/agent/off_policy_algorithm.py +++ b/tune/protox/agent/off_policy_algorithm.py @@ -43,7 +43,7 @@ def __init__( replay_buffer: ReplayBuffer, learning_starts: int = 100, batch_size: int = 256, - train_freq: Tuple[int, str] = (1, "step"), + train_freq: tuple[int, str] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, seed: Optional[int] = None, @@ -62,7 +62,7 @@ def __init__( # Save train freq parameter, will be converted later to TrainFreq object self.train_freq = self._convert_train_freq(train_freq) - def _convert_train_freq(self, train_freq: Tuple[int, str]) -> TrainFreq: + def _convert_train_freq(self, train_freq: tuple[int, str]) -> TrainFreq: """ Convert `train_freq` parameter (int or tuple) to a TrainFreq object. @@ -91,7 +91,7 @@ def _store_transition( new_obs: NDArray[np.float32], reward: float, dones: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: """ Store transition in the replay buffer. @@ -135,7 +135,7 @@ def _sample_action( self, learning_starts: int, action_noise: Optional[ActionNoise] = None, - ) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: + ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: raise NotImplementedError() def collect_rollouts( diff --git a/tune/protox/agent/policies.py b/tune/protox/agent/policies.py index 85fedb8a..3464539d 100644 --- a/tune/protox/agent/policies.py +++ b/tune/protox/agent/policies.py @@ -83,7 +83,7 @@ def __init__( self, observation_space: spaces.Space[Any], action_space: spaces.Space[Any], - net_arch: List[int], + net_arch: list[int], features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, weight_init: Optional[str] = None, @@ -150,7 +150,7 @@ def __init__( self, observation_space: spaces.Space[Any], action_space: spaces.Space[Any], - net_arch: List[int], + net_arch: list[int], features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, weight_init: Optional[str] = None, @@ -178,7 +178,7 @@ def __init__( self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]: with th.set_grad_enabled(True): features = self.extract_features(obs) qvalue_input = th.cat([features, actions], dim=1) diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index c1284569..63b39284 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -308,7 +308,7 @@ def _execute_workload_wrapper(actions_info: ActionsInfo) -> tuple[int, int, bool current_step = 0 start_found = False start_time = None - existing_index_acts: Set[IndexAction] = set() + existing_index_acts: set[IndexAction] = set() for line in f: # Keep going until we've found the start. 
diff --git a/tune/protox/agent/torch_layers.py b/tune/protox/agent/torch_layers.py index 941478c4..1be91a1e 100644 --- a/tune/protox/agent/torch_layers.py +++ b/tune/protox/agent/torch_layers.py @@ -33,14 +33,14 @@ def init_layer( def create_mlp( input_dim: int, output_dim: int, - net_arch: List[int], + net_arch: list[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, weight_init: Optional[str] = None, bias_zero: bool = False, final_layer_adjust: float = 1.0, -) -> List[nn.Module]: +) -> list[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. diff --git a/tune/protox/agent/wolp/policies.py b/tune/protox/agent/wolp/policies.py index 0d1f2a27..906a3750 100644 --- a/tune/protox/agent/wolp/policies.py +++ b/tune/protox/agent/wolp/policies.py @@ -98,7 +98,7 @@ def discriminate( embed_actions: th.Tensor, actions_dim: th.Tensor, env_actions: list[HolonAction], - ) -> Tuple[list[HolonAction], th.Tensor]: + ) -> tuple[list[HolonAction], th.Tensor]: states_tile = states.repeat_interleave(actions_dim, dim=0) if use_target: next_q_values = th.cat( @@ -140,7 +140,7 @@ def wolp_act( action_noise: Optional[Union[ActionNoise, th.Tensor]] = None, neighbor_parameters: NeighborParameters = DEFAULT_NEIGHBOR_PARAMETERS, random_act: bool = False, - ) -> Tuple[list[HolonAction], th.Tensor]: + ) -> tuple[list[HolonAction], th.Tensor]: # Get the tensor representation. start_time = time.time() if not isinstance(states, th.Tensor): diff --git a/tune/protox/agent/wolp/wolp.py b/tune/protox/agent/wolp/wolp.py index ba519258..6b4f5c8e 100644 --- a/tune/protox/agent/wolp/wolp.py +++ b/tune/protox/agent/wolp/wolp.py @@ -47,12 +47,12 @@ def __init__( replay_buffer: ReplayBuffer, learning_starts: int = 100, batch_size: int = 100, - train_freq: Tuple[int, str] = (1, "episode"), + train_freq: tuple[int, str] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, target_action_noise: Optional[ActionNoise] = None, seed: Optional[int] = None, - neighbor_parameters: Dict[str, Any] = {}, + neighbor_parameters: dict[str, Any] = {}, ray_trial_id: Optional[str] = None, ): super().__init__( @@ -77,7 +77,7 @@ def _store_transition( new_obs: NDArray[np.float32], reward: float, dones: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: """ Store transition in the replay buffer. @@ -124,7 +124,7 @@ def _sample_action( self, learning_starts: int, action_noise: Optional[ActionNoise] = None, - ) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: + ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: """ Sample an action according to the exploration policy. 
This is either done by sampling the probability distribution of the policy, diff --git a/tune/protox/embedding/loss.py b/tune/protox/embedding/loss.py index b3f28843..ec34473e 100644 --- a/tune/protox/embedding/loss.py +++ b/tune/protox/embedding/loss.py @@ -24,11 +24,11 @@ def get_loss(distance_fn: str) -> nn.Module: def get_bias_fn( config: dict[str, Any] ) -> Callable[ - [torch.Tensor, torch.Tensor], Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + [torch.Tensor, torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] ]: def bias_fn( data: torch.Tensor, labels: torch.Tensor - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: red_index = COST_COLUMNS.index(config["cost_reduction_type"]) distance_scale = config["distance_scale"] if distance_scale == "auto": @@ -74,7 +74,7 @@ def _distance_cost( targets: torch.Tensor, bias: Callable[ [torch.Tensor, torch.Tensor], - Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], ], output_scale: float, ) -> Any: diff --git a/tune/protox/embedding/train_all.py b/tune/protox/embedding/train_all.py index e8358387..66062593 100644 --- a/tune/protox/embedding/train_all.py +++ b/tune/protox/embedding/train_all.py @@ -46,7 +46,7 @@ from tune.protox.env.workload import Workload -def fetch_vae_parameters_from_workload(w: Workload, ntables: int) -> Tuple[int, int]: +def fetch_vae_parameters_from_workload(w: Workload, ntables: int) -> tuple[int, int]: max_indexable = w.max_indexable() max_cat_features = max( ntables, max_indexable + 1 @@ -59,7 +59,7 @@ def fetch_index_parameters( dbgym_cfg: DBGymConfig, data: dict[str, Any], workload_path: Path, -) -> Tuple[int, int, TableAttrListMap, dict[TableColTuple, int]]: +) -> tuple[int, int, TableAttrListMap, dict[TableColTuple, int]]: tables = data["tables"] attributes = data["attributes"] query_spec = data["query_spec"] @@ -94,7 +94,7 @@ def load_input_data( max_attrs: int, require_cost: bool, seed: int, -) -> Tuple[TensorDataset, Any, Any, Optional[TensorDataset], int]: +) -> tuple[TensorDataset, Any, Any, Optional[TensorDataset], int]: # Load the input data. 
columns = [] columns += ["tbl_index", "idx_class"] diff --git a/tune/protox/embedding/trainer.py b/tune/protox/embedding/trainer.py index 19648aa5..e259f9c9 100644 --- a/tune/protox/embedding/trainer.py +++ b/tune/protox/embedding/trainer.py @@ -26,7 +26,7 @@ def __init__( self.elem_per_class = 0 assert self.batch_size > 0 - def compute(self) -> Tuple[dict[int, Tuple[int, NDArray[Any]]], int, int]: + def compute(self) -> tuple[dict[int, tuple[int, NDArray[Any]]], int, int]: r = {} for c in range(self.max_class): lc = np.argwhere(self.labels == c) @@ -80,7 +80,7 @@ def __init__( bias_fn: Optional[ Callable[ [torch.Tensor, torch.Tensor], - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], ] ], *args: Any, @@ -90,7 +90,7 @@ def __init__( self.failed = False self.fail_msg: Optional[str] = None self.fail_data: Optional[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ] = None self.disable_tqdm = disable_tqdm self.bias_fn = bias_fn @@ -117,7 +117,7 @@ def maybe_get_vae_loss( ) return 0 - def calculate_loss(self, curr_batch: Tuple[torch.Tensor, torch.Tensor]) -> None: + def calculate_loss(self, curr_batch: tuple[torch.Tensor, torch.Tensor]) -> None: data, labels = curr_batch if labels.shape[1] == 1: # Flatten labels if it's a class. @@ -232,7 +232,7 @@ def train(self, start_epoch: int = 1, num_epochs: int = 1) -> None: def compute_embeddings(self, base_output: Any) -> None: assert False - def get_batch(self) -> Tuple[torch.Tensor, torch.Tensor]: + def get_batch(self) -> tuple[torch.Tensor, torch.Tensor]: self.dataloader_iter, curr_batch = c_f.try_next_on_generator(self.dataloader_iter, self.dataloader) # type: ignore data, labels = self.data_and_label_getter(curr_batch) return data, labels diff --git a/tune/protox/embedding/vae.py b/tune/protox/embedding/vae.py index eb03a3f0..9040d49c 100644 --- a/tune/protox/embedding/vae.py +++ b/tune/protox/embedding/vae.py @@ -9,10 +9,10 @@ def gen_vae_collate( max_categorical: int, infer: bool = False -) -> Callable[[list[Any]], Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: +) -> Callable[[list[Any]], Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: def vae_collate( batch: list[Any], - ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if infer: x = torch.as_tensor(batch).type(torch.int64) else: @@ -120,7 +120,7 @@ def forward( embeddings: torch.Tensor, labels: Any = None, indices_tuple: Any = None, - ref_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ref_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ref_labels: Any = None, is_eval: bool = False, ) -> Any: @@ -149,7 +149,7 @@ def compute_loss( preds: torch.Tensor, unused0: Any, unused1: Any, - tdata: Optional[Tuple[torch.Tensor, torch.Tensor]], + tdata: Optional[tuple[torch.Tensor, torch.Tensor]], *args: Any, **kwargs: Any ) -> Any: @@ -353,16 +353,16 @@ def get_collate(self) -> Callable[[torch.Tensor], torch.Tensor]: def forward( self, x: torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, bool], Tuple[torch.Tensor, bool]]: + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> Union[tuple[torch.Tensor, torch.Tensor, bool], tuple[torch.Tensor, bool]]: return self._compute(x, bias=bias, require_full=True) def latents( self, x: 
torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, require_full: bool = False, - ) -> Tuple[torch.Tensor, bool]: + ) -> tuple[torch.Tensor, bool]: rets = self._compute(x, bias=bias, require_full=False) assert len(rets) == 2 return rets[0], rets[1] @@ -370,9 +370,9 @@ def latents( def _compute( self, x: torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, require_full: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, bool], Tuple[torch.Tensor, bool]]: + ) -> Union[tuple[torch.Tensor, torch.Tensor, bool], tuple[torch.Tensor, bool]]: latents: torch.Tensor = self.encoder(x) latents = latents * self.output_scale diff --git a/tune/protox/env/lsc/lsc_wrapper.py b/tune/protox/env/lsc/lsc_wrapper.py index 5f5d464c..5d4ff5a5 100644 --- a/tune/protox/env/lsc/lsc_wrapper.py +++ b/tune/protox/env/lsc/lsc_wrapper.py @@ -14,7 +14,7 @@ def __init__(self, lsc: LSC, env: gym.Env[Any, Any], logger: Optional[Logger]): self.lsc = lsc self.logger = logger - def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, dict[str, Any]]: state, info = self.env.reset(*args, **kwargs) self.lsc.reset() @@ -27,7 +27,7 @@ def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: def step( self, *args: Any, **kwargs: Any - ) -> Tuple[Any, float, bool, bool, dict[str, Any]]: + ) -> tuple[Any, float, bool, bool, dict[str, Any]]: state, reward, term, trunc, info = self.env.step(*args, **kwargs) # Remember the LSC when we considered this action. diff --git a/tune/protox/env/mqo/mqo_wrapper.py b/tune/protox/env/mqo/mqo_wrapper.py index acd5456f..50b9dc39 100644 --- a/tune/protox/env/mqo/mqo_wrapper.py +++ b/tune/protox/env/mqo/mqo_wrapper.py @@ -199,7 +199,7 @@ def _update_best_observed( def step( # type: ignore self, action: HolonAction, - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: # Step based on the "global" action. assert isinstance(self.unwrapped, PostgresEnv) success, info = self.unwrapped.step_before_execution(action) @@ -331,7 +331,7 @@ def transmute( return self.unwrapped.step_post_execute(success, action, info) - def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, EnvInfoDict]: # type: ignore + def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, EnvInfoDict]: # type: ignore assert isinstance(self.unwrapped, PostgresEnv) # First have to shift to the new state. state, info = self.unwrapped.reset(*args, **kwargs) diff --git a/tune/protox/env/pg_env.py b/tune/protox/env/pg_env.py index fe6c5ba1..5f94e587 100644 --- a/tune/protox/env/pg_env.py +++ b/tune/protox/env/pg_env.py @@ -79,7 +79,7 @@ def _restore_last_snapshot(self) -> None: @time_record("reset") def reset( # type: ignore self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None - ) -> Tuple[Any, EnvInfoDict]: + ) -> tuple[Any, EnvInfoDict]: reset_start = time.time() if self.logger: self.logger.get_logger(__name__).info( @@ -213,7 +213,7 @@ def reset( # type: ignore return self.current_state, info @time_record("step_before_execution") - def step_before_execution(self, action: HolonAction) -> Tuple[bool, EnvInfoDict]: + def step_before_execution(self, action: HolonAction) -> tuple[bool, EnvInfoDict]: # Log the action in debug mode. 
if self.logger: self.logger.get_logger(__name__).debug( @@ -249,9 +249,9 @@ def step_before_execution(self, action: HolonAction) -> Tuple[bool, EnvInfoDict] def step_execute( self, setup_success: bool, - all_holon_action_variations: list[Tuple[str, HolonAction]], + all_holon_action_variations: list[tuple[str, HolonAction]], info: EnvInfoDict, - ) -> Tuple[bool, EnvInfoDict]: + ) -> tuple[bool, EnvInfoDict]: if setup_success: assert isinstance(self.observation_space, StateSpace) assert isinstance(self.action_space, HolonSpace) @@ -320,7 +320,7 @@ def step_post_execute( action: HolonAction, info: EnvInfoDict, soft: bool = False, - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: if self.workload.oltp_workload and self.horizon > 1: # If horizon = 1, then we're going to reset anyways. So easier to just untar the original archive. # Restore the crisp and clean snapshot. @@ -366,7 +366,7 @@ def step_post_execute( def step( # type: ignore self, action: HolonAction - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: assert self.tuning_mode != TuningMode.REPLAY success, info = self.step_before_execution(action) success, info = self.step_execute(success, [("PerQuery", action)], info) diff --git a/tune/protox/env/space/holon_space.py b/tune/protox/env/space/holon_space.py index cc66ab4c..870ecedd 100644 --- a/tune/protox/env/space/holon_space.py +++ b/tune/protox/env/space/holon_space.py @@ -81,9 +81,9 @@ def __init__( self.space_dims: Optional[list[int]] = None self.logger = logger - def get_spaces(self) -> list[Tuple[str, HolonSubSpace]]: + def get_spaces(self) -> list[tuple[str, HolonSubSpace]]: r = cast( - list[Tuple[str, HolonSubSpace]], + list[tuple[str, HolonSubSpace]], [(s.name, s) for s in self.spaces if hasattr(s, "name")], ) assert len(r) == 3 @@ -98,7 +98,7 @@ def null_action(self, sc: HolonStateContainer) -> HolonAction: def split_action( self, action: HolonAction - ) -> list[Tuple[HolonSubSpace, HolonSubAction]]: + ) -> list[tuple[HolonSubSpace, HolonSubAction]]: return [ (cast(LatentKnobSpace, self.spaces[0]), action[0]), (cast(LatentIndexSpace, self.spaces[1]), action[1]), @@ -230,7 +230,7 @@ def neighborhood( self, raw_action: ProtoAction, neighbor_parameters: NeighborParameters = DEFAULT_NEIGHBOR_PARAMETERS, - ) -> Tuple[list[HolonAction], ProtoAction, torch.Tensor]: + ) -> tuple[list[HolonAction], ProtoAction, torch.Tensor]: env_acts = [] emb_acts: list[torch.Tensor] = [] ndims = [] @@ -329,7 +329,7 @@ def generate_state_container( prev_state_container: Optional[HolonStateContainer], action: Optional[HolonAction], connection: Connection[Any], - queries: dict[str, list[Tuple[QueryType, str]]], + queries: dict[str, list[tuple[QueryType, str]]], ) -> HolonStateContainer: t = tuple( s.generate_state_container( @@ -346,7 +346,7 @@ def generate_state_container( def generate_action_plan( self, action: HolonAction, state_container: HolonStateContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: outputs = [ space.generate_action_plan(action[i], state_container[i], **kwargs) for i, space in enumerate(self.spaces) @@ -359,7 +359,7 @@ def generate_action_plan( def generate_plan_from_config( self, config: HolonStateContainer, sc: HolonStateContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: outputs = [ space.generate_delta_action_plan(config[i], sc[i], **kwargs) for i, space in enumerate(self.spaces) diff 
--git a/tune/protox/env/space/latent_space/latent_index_space.py b/tune/protox/env/space/latent_space/latent_index_space.py index 33d59466..f92c98f7 100644 --- a/tune/protox/env/space/latent_space/latent_index_space.py +++ b/tune/protox/env/space/latent_space/latent_index_space.py @@ -250,7 +250,7 @@ def generate_state_container( def generate_action_plan( self, action: IndexSpaceRawSample, sc: IndexSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: assert check_subspace(self, action) sql_commands = [] @@ -277,7 +277,7 @@ def generate_action_plan( def generate_delta_action_plan( self, action: IndexSpaceContainer, sc: IndexSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: assert isinstance(action, list) acts = [] sql_commands = [] diff --git a/tune/protox/env/space/latent_space/latent_knob_space.py b/tune/protox/env/space/latent_space/latent_knob_space.py index 6d1a97ea..caa923ee 100644 --- a/tune/protox/env/space/latent_space/latent_knob_space.py +++ b/tune/protox/env/space/latent_space/latent_knob_space.py @@ -181,7 +181,7 @@ def generate_state_container( def generate_action_plan( self, action: KnobSpaceAction, sc: KnobSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: config_changes = [] sql_commands = [] require_cleanup = False @@ -235,5 +235,5 @@ def generate_action_plan( def generate_delta_action_plan( self, action: KnobSpaceAction, sc: KnobSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return self.generate_action_plan(action, sc, **kwargs) diff --git a/tune/protox/env/space/latent_space/latent_query_space.py b/tune/protox/env/space/latent_space/latent_query_space.py index 0bfa40b4..a2668206 100644 --- a/tune/protox/env/space/latent_space/latent_query_space.py +++ b/tune/protox/env/space/latent_space/latent_query_space.py @@ -41,12 +41,12 @@ def generate_state_container( def generate_action_plan( self, action: QuerySpaceAction, sc: QuerySpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return [], [] def generate_delta_action_plan( self, action: QuerySpaceAction, sc: QuerySpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return [], [] def extract_query(self, action: QuerySpaceAction) -> QuerySpaceKnobAction: diff --git a/tune/protox/env/space/primitive/knob.py b/tune/protox/env/space/primitive/knob.py index 7905bf70..a09ce942 100644 --- a/tune/protox/env/space/primitive/knob.py +++ b/tune/protox/env/space/primitive/knob.py @@ -27,7 +27,7 @@ def full_knob_name( return knob_name -def _parse_setting_dtype(type_str: str) -> Tuple[SettingType, Any]: +def _parse_setting_dtype(type_str: str) -> tuple[SettingType, Any]: return { "boolean": (SettingType.BOOLEAN, np.int32), "integer": (SettingType.INTEGER, np.int32), @@ -229,7 +229,7 @@ def _flatdim_knob(space: Knob) -> int: return 1 -def _categorical_elems(type_str: str) -> Tuple[SettingType, int]: +def _categorical_elems(type_str: str) -> tuple[SettingType, int]: return { "scanmethod_enum_categorical": (SettingType.SCANMETHOD_ENUM_CATEGORICAL, 2), }[type_str] diff --git a/tune/protox/env/space/primitive_space/index_policy.py b/tune/protox/env/space/primitive_space/index_policy.py index 99390e57..dd03d209 100644 --- a/tune/protox/env/space/primitive_space/index_policy.py +++ b/tune/protox/env/space/primitive_space/index_policy.py @@ -42,7 
+42,7 @@ def __init__( self.index_space_aux_include = index_space_aux_include def spaces(self, seed: int) -> Sequence[spaces.Space[Any]]: - aux: List[spaces.Space[Any]] = [ + aux: list[spaces.Space[Any]] = [ # One-hot encoding for the tables. spaces.Discrete(self.num_tables, seed=seed), # Ordering. Note that we use the postgres style ordinal notation. 0 is illegal/end-of-index. @@ -67,7 +67,7 @@ def spaces(self, seed: int) -> Sequence[spaces.Space[Any]]: ) ] - return cast(List[spaces.Space[Any]], aux_type + aux + aux_include) + return cast(list[spaces.Space[Any]], aux_type + aux + aux_include) def to_action(self, act: IndexSpaceRawSample) -> IndexAction: # First index is the index type. diff --git a/tune/protox/env/space/primitive_space/index_space.py b/tune/protox/env/space/primitive_space/index_space.py index ca4b7a06..f8e8ff41 100644 --- a/tune/protox/env/space/primitive_space/index_space.py +++ b/tune/protox/env/space/primitive_space/index_space.py @@ -107,7 +107,7 @@ def null_action(self) -> IndexSpaceRawSample: action[0] = 1.0 return self.policy.sample_dist(action, self.np_random, sample_num_columns=False) - def to_jsonable(self, sample_n) -> List[str]: # type: ignore + def to_jsonable(self, sample_n) -> list[str]: # type: ignore # Emit the representation of an index. ias = [self.to_action(sample) for sample in sample_n] return [ia.__repr__() for ia in ias] diff --git a/tune/protox/env/space/state/structure.py b/tune/protox/env/space/state/structure.py index af29c1cf..c5b2ab19 100644 --- a/tune/protox/env/space/state/structure.py +++ b/tune/protox/env/space/state/structure.py @@ -32,7 +32,7 @@ def __init__( self.normalize = normalize if self.normalize: - self.internal_spaces: Dict[str, gym.spaces.Space[Any]] = { + self.internal_spaces: dict[str, gym.spaces.Space[Any]] = { k: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(s.critic_dim(),)) for k, s in action_space.get_spaces() } diff --git a/tune/protox/env/space/utils.py b/tune/protox/env/space/utils.py index c1f79a4c..734a4ca8 100644 --- a/tune/protox/env/space/utils.py +++ b/tune/protox/env/space/utils.py @@ -224,7 +224,7 @@ def fetch_server_knobs( def fetch_server_indexes( connection: Connection[Any], tables: list[str] -) -> typing.Tuple[TableAttrListMap, ServerTableIndexMetadata]: +) -> tuple[TableAttrListMap, ServerTableIndexMetadata]: rel_metadata = TableAttrListMap({t: [] for t in tables}) existing_indexes = ServerTableIndexMetadata({}) with connection.cursor(row_factory=dict_row) as cursor: diff --git a/tune/protox/env/target_reset/target_reset_wrapper.py b/tune/protox/env/target_reset/target_reset_wrapper.py index 800ec60a..edfcf520 100644 --- a/tune/protox/env/target_reset/target_reset_wrapper.py +++ b/tune/protox/env/target_reset/target_reset_wrapper.py @@ -36,7 +36,7 @@ def _get_state(self) -> HolonStateContainer: def step( # type: ignore self, *args: Any, **kwargs: Any - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: """Steps through the environment, normalizing the rewards returned.""" obs, rews, terms, truncs, infos = self.env.step(*args, **kwargs) query_metric_data = infos.get("query_metric_data", None) @@ -81,7 +81,7 @@ def step( # type: ignore ] return obs, rews, terms, truncs, infos - def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: if len(self.tracked_states) == 0: # First time.
state, info = self.env.reset(**kwargs) diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index ec7d22f6..d79ec8db 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -75,12 +75,12 @@ class ServerIndexMetadata(TypedDict, total=False): QuerySpaceContainer: TypeAlias = KnobSpaceContainer # ([idx_type], [table_encoding], [key1_encoding], ... [key#_encoding], [include_mask]) -IndexSpaceRawSample = NewType("IndexSpaceRawSample", Tuple[Any, ...]) +IndexSpaceRawSample = NewType("IndexSpaceRawSample", tuple[Any, ...]) # [IndexAction(index1), ...] IndexSpaceContainer = NewType("IndexSpaceContainer", list["IndexAction"]) # (table_name, column_name) -TableColTuple = NewType("TableColTuple", Tuple[str, str]) +TableColTuple = NewType("TableColTuple", tuple[str, str]) # {table: [att1, att2, ...], ...} TableAttrListMap = NewType("TableAttrListMap", dict[str, list[str]]) @@ -91,7 +91,7 @@ class ServerIndexMetadata(TypedDict, total=False): # {table: set[ (att1, att3), (att3, att4), ... ], ...} # This maps a table to a set of attributes accessed together. TableAttrAccessSetsMap = NewType( - "TableAttrAccessSetsMap", dict[str, set[Tuple[str, ...]]] + "TableAttrAccessSetsMap", dict[str, set[tuple[str, ...]]] ) # {qid: {table: scan_method, ...}, ...} @@ -101,11 +101,11 @@ class ServerIndexMetadata(TypedDict, total=False): # {qid: {table: [alias1, alias2, ...], ...}, ...} QueryTableAliasMap = NewType("QueryTableAliasMap", dict[str, TableAliasMap]) # {qid: [(query_type1, query_str1), (query_type2, query_str2), ...], ...} -QueryMap = NewType("QueryMap", dict[str, list[Tuple[QueryType, str]]]) +QueryMap = NewType("QueryMap", dict[str, list[tuple[QueryType, str]]]) HolonAction = NewType( "HolonAction", - Tuple[ + tuple[ KnobSpaceAction, IndexSpaceRawSample, QuerySpaceAction, @@ -114,7 +114,7 @@ class ServerIndexMetadata(TypedDict, total=False): HolonStateContainer = NewType( "HolonStateContainer", - Tuple[ + tuple[ KnobSpaceContainer, IndexSpaceContainer, QuerySpaceContainer, @@ -167,7 +167,7 @@ class QuerySpec(TypedDict, total=False): class ActionsInfo(TypedDict): - all_holon_action_variations: list[Tuple[str, HolonAction]] + all_holon_action_variations: list[tuple[str, HolonAction]] best_observed_holon_action: Optional[HolonAction] @@ -187,7 +187,7 @@ class EnvInfoDict(TypedDict, total=False): prior_pgconf: Optional[Union[str, Path]] # Changes made to the DBMS during this step. - attempted_changes: Tuple[list[str], list[str]] + attempted_changes: tuple[list[str], list[str]] # Metric of this step. 
metric: Optional[float] diff --git a/tune/protox/env/util/execute.py b/tune/protox/env/util/execute.py index 6ec5d695..fbbe9a4c 100644 --- a/tune/protox/env/util/execute.py +++ b/tune/protox/env/util/execute.py @@ -36,7 +36,7 @@ def _time_query( connection: psycopg.Connection[Any], query: str, timeout: float, -) -> Tuple[float, bool, Any]: +) -> tuple[float, bool, Any]: did_time_out = False has_explain = "EXPLAIN" in query explain_data = None @@ -77,7 +77,7 @@ def _acquire_metrics_around_query( query: str, query_timeout: float = 0.0, observation_space: Optional[StateSpace] = None, -) -> Tuple[float, bool, Any, Any]: +) -> tuple[float, bool, Any, Any]: _force_statement_timeout(connection, 0) if observation_space and observation_space.require_metrics(): initial_metrics = observation_space.construct_online(connection) diff --git a/tune/protox/env/util/pg_conn.py b/tune/protox/env/util/pg_conn.py index cc89d722..3faf66a6 100644 --- a/tune/protox/env/util/pg_conn.py +++ b/tune/protox/env/util/pg_conn.py @@ -302,7 +302,7 @@ def _set_up_boot( self.logger.get_logger(__name__).debug("Set up boot") @time_record("psql") - def psql(self, sql: str) -> Tuple[int, Optional[str]]: + def psql(self, sql: str) -> tuple[int, Optional[str]]: low_sql = sql.lower() def cancel_fn(conn_str: str) -> None: diff --git a/tune/protox/env/util/reward.py b/tune/protox/env/util/reward.py index ba01b8a0..9b5046d7 100644 --- a/tune/protox/env/util/reward.py +++ b/tune/protox/env/util/reward.py @@ -52,7 +52,7 @@ def set_relative_baseline( def parse_tps_avg_p99_for_metric( self, parent: Union[Path, str] - ) -> Tuple[float, float, float]: + ) -> tuple[float, float, float]: files = [f for f in Path(parent).rglob("*.summary.json")] assert len(files) == 1 @@ -99,7 +99,7 @@ def __call__( metric: Optional[float] = None, update: bool = True, did_error: bool = False, - ) -> Tuple[float, float]: + ) -> tuple[float, float]: # TODO: we need to get the memory consumption of indexes. if the index usage # exceeds the limit, then kill the reward function. 
may also want to penalize diff --git a/tune/protox/env/util/workload_analysis.py b/tune/protox/env/util/workload_analysis.py index 9db06e5e..7df507bf 100644 --- a/tune/protox/env/util/workload_analysis.py +++ b/tune/protox/env/util/workload_analysis.py @@ -72,7 +72,7 @@ def extract_aliases(stmts: pglast.ast.Node) -> TableAliasMap: def extract_sqltypes( stmts: pglast.ast.Node, pid: Optional[int] -) -> list[Tuple[QueryType, str]]: +) -> list[tuple[QueryType, str]]: sqls = [] for stmt in stmts: sql_type = QueryType.UNKNOWN @@ -114,7 +114,7 @@ def extract_columns( tables: list[str], all_attributes: AttrTableListMap, query_aliases: TableAliasMap, -) -> Tuple[TableAttrSetMap, list[TableColTuple]]: +) -> tuple[TableAttrSetMap, list[TableColTuple]]: tbl_col_usages: TableAttrSetMap = TableAttrSetMap({t: set() for t in tables}) def traverse_extract_columns( diff --git a/tune/protox/env/workload.py b/tune/protox/env/workload.py index 5ec72f7e..e0242d65 100644 --- a/tune/protox/env/workload.py +++ b/tune/protox/env/workload.py @@ -62,7 +62,7 @@ def _open_for_reading( def _crunch( self, all_attributes: AttrTableListMap, - sqls: list[Tuple[str, Path, float]], + sqls: list[tuple[str, Path, float]], pid: Optional[int], query_spec: QuerySpec, ) -> None: @@ -356,7 +356,7 @@ def execute_workload( workload_qdir: Optional[tuple[Path, Path]] = None, blocklist: list[str] = [], first: bool = False, - ) -> Tuple[int, bool, dict[str, Any]]: + ) -> tuple[int, bool, dict[str, Any]]: this_execution_workload_timeout = ( self.workload_timeout if not override_workload_timeout @@ -378,7 +378,7 @@ def execute_workload( ][0], ) ql_knobs = cast( - list[Tuple[LatentQuerySpace, QuerySpaceAction]], + list[tuple[LatentQuerySpace, QuerySpaceAction]], [ [ (t, v) @@ -654,7 +654,7 @@ def execute( reset_metrics: Optional[dict[str, BestQueryRun]] = None, update: bool = True, first: bool = False, - ) -> Tuple[bool, float, float, Union[str, Path], bool, dict[str, BestQueryRun]]: + ) -> tuple[bool, float, float, Union[str, Path], bool, dict[str, BestQueryRun]]: success = True if self.logger: self.logger.get_logger(__name__).info("Starting to run benchmark...") diff --git a/util/pg.py b/util/pg.py index 9e08f07e..2cf4506d 100644 --- a/util/pg.py +++ b/util/pg.py @@ -20,7 +20,7 @@ def conn_execute(conn: Connection, sql: str) -> CursorResult[Any]: return conn.execute(text(sql)) -def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> List[str]: +def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> list[str]: with open_and_save(dbgym_cfg, filepath) as f: lines: list[str] = [] for line in f: From e62f612084a925815487551a2c0e417e260ba8fa Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:39:44 +0000 Subject: [PATCH 48/60] reformatted --- benchmark/tpch/load_info.py | 1 + manage/cli.py | 12 +- manage/tests/test_clean.py | 532 ++++++++++++++--------- misc/utils.py | 24 +- scripts/read_parquet.py | 2 +- tune/protox/agent/base_class.py | 4 +- tune/protox/agent/build_trial.py | 9 +- tune/protox/agent/hpo.py | 4 +- tune/protox/agent/replay.py | 6 +- tune/protox/agent/wolp/policies.py | 8 +- tune/protox/embedding/train.py | 32 +- tune/protox/embedding/train_args.py | 2 +- tune/protox/env/logger.py | 2 +- tune/protox/env/mqo/mqo_wrapper.py | 2 +- tune/protox/env/pg_env.py | 10 +- tune/protox/env/types.py | 2 +- tune/protox/tests/test_index_space.py | 10 +- tune/protox/tests/test_workload.py | 6 +- tune/protox/tests/test_workload_utils.py | 136 +++--- util/pg.py | 8 +- util/shell.py | 9 +- 21 files changed, 478 
insertions(+), 343 deletions(-) diff --git a/benchmark/tpch/load_info.py b/benchmark/tpch/load_info.py index e678c6be..1076c1f7 100644 --- a/benchmark/tpch/load_info.py +++ b/benchmark/tpch/load_info.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Optional + from dbms.load_info_base_class import LoadInfoBaseClass from misc.utils import DBGymConfig, get_scale_factor_string diff --git a/manage/cli.py b/manage/cli.py index 624b248f..af839570 100644 --- a/manage/cli.py +++ b/manage/cli.py @@ -6,7 +6,13 @@ import click -from misc.utils import DBGymConfig, get_runs_path_from_workspace_path, get_symlinks_path_from_workspace_path, is_child_path, parent_dpath_of_path +from misc.utils import ( + DBGymConfig, + get_runs_path_from_workspace_path, + get_symlinks_path_from_workspace_path, + is_child_path, + parent_dpath_of_path, +) task_logger = logging.getLogger("task") task_logger.setLevel(logging.INFO) @@ -83,7 +89,9 @@ def _count_files_in_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig) -> int: return total_count -def clean_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose: bool=False) -> None: +def clean_workspace( + dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose: bool = False +) -> None: """ Clean all [workspace]/task_runs/run_*/ directories that are not referenced by any "active symlinks". If mode is "aggressive", "active symlinks" means *only* the symlinks directly in [workspace]/symlinks/. diff --git a/manage/tests/test_clean.py b/manage/tests/test_clean.py index 27fcc305..20beefbf 100644 --- a/manage/tests/test_clean.py +++ b/manage/tests/test_clean.py @@ -2,14 +2,12 @@ import logging import os import shutil -from typing import Any, NewType, cast import unittest from pathlib import Path +from typing import Any, NewType, cast from manage.cli import MockDBGymConfig, clean_workspace -from misc.utils import ( - path_exists_dont_follow_symlinks, -) +from misc.utils import path_exists_dont_follow_symlinks # This is here instead of on `if __name__ == "__main__"` because we often run individual tests, which # does not go through the `if __name__ == "__main__"` codepath. @@ -25,6 +23,7 @@ class CleanTests(unittest.TestCase): I deemed "clean" important enough to write extensive unit tests for because a bug could lead to losing important files. 
""" + scratchspace_path: Path = Path() @staticmethod @@ -37,7 +36,11 @@ def create_structure_internal( if isinstance(content, dict): # Directory full_path.mkdir(parents=True, exist_ok=True) - create_structure_internal(root_path, full_path, FilesystemStructure(cast(dict[str, Any], content))) + create_structure_internal( + root_path, + full_path, + FilesystemStructure(cast(dict[str, Any], content)), + ) elif isinstance(content, tuple) and content[0] == "file": assert len(content) == 1 full_path.touch() @@ -66,7 +69,11 @@ def verify_structure_internal( if not new_cur_path.is_dir(): logging.debug(f"expected {new_cur_path} to be a directory") return False - if not verify_structure_internal(root_path, new_cur_path, FilesystemStructure(cast(dict[str, Any], item))): + if not verify_structure_internal( + root_path, + new_cur_path, + FilesystemStructure(cast(dict[str, Any], item)), + ): return False elif isinstance(item, tuple) and item[0] == "file": if not new_cur_path.is_file(): @@ -105,16 +112,19 @@ def verify_structure_internal( @staticmethod def make_workspace_structure( - symlinks_structure: FilesystemStructure, task_runs_structure: FilesystemStructure + symlinks_structure: FilesystemStructure, + task_runs_structure: FilesystemStructure, ) -> FilesystemStructure: """ This function exists so that it's easier to refactor the tests in case we ever change how the workspace is organized. """ - return FilesystemStructure({ - "symlinks": symlinks_structure, - "task_runs": task_runs_structure, - }) + return FilesystemStructure( + { + "symlinks": symlinks_structure, + "task_runs": task_runs_structure, + } + ) @classmethod def setUpClass(cls) -> None: @@ -129,12 +139,14 @@ def tearDown(self) -> None: shutil.rmtree(self.scratchspace_path) def test_structure_helpers(self) -> None: - structure = FilesystemStructure({ - "dir1": {"file1.txt": ("file",), "dir2": {"file2.txt": ("file",)}}, - "dir3": {"nested_link_to_dir1": ("symlink", "dir1")}, - "link_to_dir1": ("symlink", "dir1"), - "link_to_file2": ("symlink", "dir1/dir2/file2.txt"), - }) + structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",), "dir2": {"file2.txt": ("file",)}}, + "dir3": {"nested_link_to_dir1": ("symlink", "dir1")}, + "link_to_dir1": ("symlink", "dir1"), + "link_to_file2": ("symlink", "dir1/dir2/file2.txt"), + } + ) CleanTests.create_structure(self.scratchspace_path, structure) self.assertTrue(CleanTests.verify_structure(self.scratchspace_path, structure)) @@ -221,7 +233,9 @@ def test_no_symlinks_dir_and_no_task_runs_dir(self) -> None: ) def test_no_symlinks_dir_and_yes_task_runs_dir(self) -> None: - starting_structure = FilesystemStructure({"task_runs": {"file1.txt": ("file",)}}) + starting_structure = FilesystemStructure( + {"task_runs": {"file1.txt": ("file",)}} + ) ending_structure = FilesystemStructure({"task_runs": {}}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) @@ -275,12 +289,18 @@ def test_no_links_in_symlinks(self) -> None: ) def test_link_to_file_directly_in_task_runs(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/file1.txt")}) - starting_task_runs_structure = FilesystemStructure({"file1.txt": ("file",), "file2.txt": ("file",)}) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + {"file1.txt": ("file",), "file2.txt": ("file",)} + ) starting_structure = 
CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/file1.txt")}) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/file1.txt")} + ) ending_task_runs_structure = FilesystemStructure({"file1.txt": ("file",)}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure @@ -293,16 +313,24 @@ def test_link_to_file_directly_in_task_runs(self) -> None: ) def test_link_to_dir_directly_in_task_runs(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"file1.txt": ("file",)}, - "dir2": {"file2.txt": ("file",)}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",)}, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({"dir1": {"file1.txt": ("file",)}}) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",)}} + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -314,20 +342,24 @@ def test_link_to_dir_directly_in_task_runs(self) -> None: ) def test_link_to_file_in_dir_in_task_runs(self) -> None: - starting_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"file1.txt": ("file",)}, - "dir2": {"file2.txt": ("file",)}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",)}, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - ending_task_runs_structure = FilesystemStructure({"dir1": {"file1.txt": ("file",)}}) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",)}} + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -339,18 +371,26 @@ def test_link_to_file_in_dir_in_task_runs(self) -> None: ) def test_link_to_dir_in_dir_in_task_runs(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1/dir2")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, - "dir3": {"file3.txt": ("file",)}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/dir2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": 
{"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, + "dir3": {"file3.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1/dir2")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/dir2")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -362,11 +402,15 @@ def test_link_to_dir_in_dir_in_task_runs(self) -> None: ) def test_link_to_link_crashes(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/symlink2")}) - starting_task_runs_structure = FilesystemStructure({ - "symlink2": ("symlink", "task_runs/file1.txt"), - "file1.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "symlink2": ("symlink", "task_runs/file1.txt"), + "file1.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -376,20 +420,28 @@ def test_link_to_link_crashes(self) -> None: clean_workspace(MockDBGymConfig(self.scratchspace_path)) def test_safe_mode_link_to_dir_with_link(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - "file2.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -401,30 +453,34 @@ def test_safe_mode_link_to_dir_with_link(self) -> None: ) def test_safe_mode_link_to_file_in_dir_with_link(self) -> None: - starting_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/file2.txt"), - }, - "file2.txt": ("file",), - "file3.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + 
) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/file2.txt"), + }, + "file2.txt": ("file",), + "file3.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - ending_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/file2.txt"), - }, - "file2.txt": ("file",), - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/file2.txt"), + }, + "file2.txt": ("file",), + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -436,24 +492,32 @@ def test_safe_mode_link_to_file_in_dir_with_link(self) -> None: ) def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, - "dir2": { - "file2.txt": ("file",), - }, - "file3.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, + "dir2": { + "file2.txt": ("file",), + }, + "file3.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, - "dir2": { - "file2.txt": ("file",), - }, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, + "dir2": { + "file2.txt": ("file",), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -465,19 +529,27 @@ def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self) -> No ) def test_aggressive_mode_link_to_dir_with_link(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - "file2.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": 
{"symlink2": ("symlink", None)}, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", None)}, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -489,13 +561,15 @@ def test_aggressive_mode_link_to_dir_with_link(self) -> None: ) def test_link_to_link_to_file_gives_error(self) -> None: - starting_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/symlink2") - }) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "task_runs/file2.txt")}, - "file2.txt": ("file",), - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file2.txt")}, + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -507,12 +581,14 @@ def test_link_to_link_to_file_gives_error(self) -> None: clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") def test_multi_link_loop_gives_error(self) -> None: - starting_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/symlink2") - }) - starting_task_runs_structure = FilesystemStructure({ - "dir1": {"symlink2": ("symlink", "symlinks/symlink1")}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "symlinks/symlink1")}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -524,7 +600,9 @@ def test_multi_link_loop_gives_error(self) -> None: clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") def test_link_self_loop_gives_error(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "symlinks/symlink1")}) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "symlinks/symlink1")} + ) starting_task_runs_structure = FilesystemStructure({}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure @@ -539,31 +617,39 @@ def test_link_self_loop_gives_error(self) -> None: def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs( self, ) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir2/file2.txt"), - }, - "dir2": { - "file2.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir2/file2.txt"), + }, + "dir2": { + "file2.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = 
FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir2/file2.txt"), - }, - "dir2": { - "file2.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir2/file2.txt"), + }, + "dir2": { + "file2.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -577,23 +663,31 @@ def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_itself( self, ) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -605,23 +699,31 @@ def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_ ) def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": 
("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -633,21 +735,27 @@ def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self) -> None: ) def test_broken_symlink_has_no_effect(self) -> None: - starting_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/non_existent_file.txt"), - }, - "dir2": {"file2.txt": ("file",)}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/non_existent_file.txt"), + }, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = FilesystemStructure({"symlink1": ("symlink", "task_runs/dir1")}) - ending_task_runs_structure = FilesystemStructure({ - "dir1": {"file1.txt": ("file",), "symlink2": ("symlink", None)} - }) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",), "symlink2": ("symlink", None)}} + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -662,34 +770,40 @@ def test_broken_symlink_has_no_effect(self) -> None: def test_link_to_folder_outside_runs_that_contains_link_to_other_run_doesnt_save_other_run( self, ) -> None: - starting_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - starting_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "external/dir3/file3.txt"), - }, - "dir2": {"file2.txt": ("file",)}, - }) + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "external/dir3/file3.txt"), + }, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - starting_structure["external"] = FilesystemStructure({ - "dir3": { - "file3.txt": ("file",), - "symlink3": ("symlink", "task_runs/dir2/file2.txt"), + starting_structure["external"] = FilesystemStructure( + { + "dir3": { + "file3.txt": ("file",), + "symlink3": ("symlink", "task_runs/dir2/file2.txt"), + } } - }) - ending_symlinks_structure = FilesystemStructure({ - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - }) - ending_task_runs_structure = FilesystemStructure({ - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "external/dir3/file3.txt"), + ) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "external/dir3/file3.txt"), + } } - }) + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) diff --git a/misc/utils.py b/misc/utils.py index 
c434e3fd..dd110405 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -4,7 +4,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import IO, Any, Callable, Tuple, Optional +from typing import IO, Any, Callable, Optional, Tuple import redis import yaml @@ -257,7 +257,7 @@ def cur_source_path(self, *dirs: str) -> Path: cur_path = cur_path / dir return cur_path - def cur_symlinks_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_symlinks_path(self, *dirs: str, mkdir: bool = False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_symlinks_path / flattened_structure for dir in dirs: @@ -266,7 +266,7 @@ def cur_symlinks_path(self, *dirs: str, mkdir: bool=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_task_runs_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_task_runs_path(self, *dirs: str, mkdir: bool = False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_this_run_path / flattened_structure for dir in dirs: @@ -275,22 +275,22 @@ def cur_task_runs_path(self, *dirs: str, mkdir: bool=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_symlinks_bin_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_symlinks_bin_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("bin", *dirs, mkdir=mkdir) - def cur_symlinks_build_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_symlinks_build_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("build", *dirs, mkdir=mkdir) - def cur_symlinks_data_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_symlinks_data_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("data", *dirs, mkdir=mkdir) - def cur_task_runs_build_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_task_runs_build_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("build", *dirs, mkdir=mkdir) - def cur_task_runs_data_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_task_runs_data_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("data", *dirs, mkdir=mkdir) - def cur_task_runs_artifacts_path(self, *dirs: str, mkdir: bool=False) -> Path: + def cur_task_runs_artifacts_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("artifacts", *dirs, mkdir=mkdir) @@ -405,7 +405,7 @@ def is_child_path(child_path: os.PathLike[str], parent_dpath: os.PathLike[str]) ) -def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode: str="r") -> IO[Any]: +def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode: str = "r") -> IO[Any]: """ Open a file and "save" it to [workspace]/task_runs/run_*/. It takes in a str | Path to match the interface of open(). @@ -541,7 +541,9 @@ def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> None: # TODO(phw2): refactor our manual symlinking in postgres/cli.py to use link_result() instead def link_result( - dbgym_cfg: DBGymConfig, result_fordpath: Path, custom_result_name: Optional[str] = None + dbgym_cfg: DBGymConfig, + result_fordpath: Path, + custom_result_name: Optional[str] = None, ) -> Path: """ result_fordpath must be a "result", meaning it was generated inside dbgym_cfg.dbgym_this_run_path. 
diff --git a/scripts/read_parquet.py b/scripts/read_parquet.py index 36b7cd35..7158ce6a 100644 --- a/scripts/read_parquet.py +++ b/scripts/read_parquet.py @@ -1,5 +1,5 @@ -from pathlib import Path import sys +from pathlib import Path import pandas as pd diff --git a/tune/protox/agent/base_class.py b/tune/protox/agent/base_class.py index 9b3dac26..e681cbb3 100644 --- a/tune/protox/agent/base_class.py +++ b/tune/protox/agent/base_class.py @@ -76,7 +76,9 @@ def _setup_learn( return total_timesteps @abstractmethod - def learn(self, env: AgentEnv, total_timesteps: int, tuning_mode: TuningMode) -> None: + def learn( + self, env: AgentEnv, total_timesteps: int, tuning_mode: TuningMode + ) -> None: """ Return a trained model. diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index 917a6ac2..e46129df 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -9,7 +9,10 @@ import numpy as np import torch from gymnasium.wrappers import FlattenObservation # type: ignore -from gymnasium.wrappers import NormalizeObservation, NormalizeReward # type: ignore[attr-defined] +from gymnasium.wrappers import ( # type: ignore[attr-defined] + NormalizeObservation, + NormalizeReward, +) from torch import nn from torch.optim import Adam # type: ignore[attr-defined] @@ -433,9 +436,7 @@ def _build_agent( policy_weight_adjustment=hpo_params["policy_weight_adjustment"], ) - actor_optimizer = Adam( - actor.parameters(), lr=hpo_params["learning_rate"] - ) + actor_optimizer = Adam(actor.parameters(), lr=hpo_params["learning_rate"]) critic = ContinuousCritic( observation_space=observation_space, diff --git a/tune/protox/agent/hpo.py b/tune/protox/agent/hpo.py index 251d526e..b4acddb5 100644 --- a/tune/protox/agent/hpo.py +++ b/tune/protox/agent/hpo.py @@ -187,7 +187,7 @@ def __init__( "--agent", type=str, default="wolp", - help=f"The RL algorithm to use for the tuning agent." + help=f"The RL algorithm to use for the tuning agent.", ) @click.option( "--max-concurrent", @@ -659,6 +659,8 @@ def cleanup(self) -> None: # https://discuss.ray.io/t/using-static-variables-to-control-trainable-subclass-in-ray-tune/808/4) # If you don't create the class with a function, it doesn't work due to how Ray serializes classes global_dbgym_cfg: DBGymConfig + + def create_tune_opt_class(dbgym_cfg_param: DBGymConfig) -> Type[Trainable]: global global_dbgym_cfg global_dbgym_cfg = dbgym_cfg_param diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index 63b39284..90a4b7ce 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -6,10 +6,10 @@ replayed tuning run is not. """ -from datetime import datetime import json import logging import pickle +from datetime import datetime from pathlib import Path from typing import Any, Optional, Set, cast @@ -249,7 +249,9 @@ def _is_tuning_step_line(line: str) -> bool: num_lines += 1 # A convenience wrapper around execute_workload() which fills in the arguments properly and processes the return values. 
- def _execute_workload_wrapper(actions_info: ActionsInfo) -> tuple[int, int, bool, float]: + def _execute_workload_wrapper( + actions_info: ActionsInfo, + ) -> tuple[int, int, bool, float]: logging.info( f"\n\nfetch_server_knobs(): {fetch_server_knobs(pg_env.pg_conn.conn(), action_space.get_knob_space().tables, action_space.get_knob_space().knobs, pg_env.workload.queries)}\n\n" ) diff --git a/tune/protox/agent/wolp/policies.py b/tune/protox/agent/wolp/policies.py index 906a3750..4882cbd3 100644 --- a/tune/protox/agent/wolp/policies.py +++ b/tune/protox/agent/wolp/policies.py @@ -244,7 +244,9 @@ def train_critic( self.critic_optimizer.zero_grad() assert not th.isnan(critic_loss).any() critic_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True) + th.nn.utils.clip_grad_norm_( + list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True + ) self.critic.check_grad() self.critic_optimizer.step() return critic_loss @@ -282,7 +284,9 @@ def train_actor(self, replay_data: ReplayBufferSamples) -> Any: self.actor_optimizer.zero_grad() assert not th.isnan(actor_loss).any() actor_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True) + th.nn.utils.clip_grad_norm_( + list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True + ) self.actor.check_grad() self.actor_optimizer.step() return actor_loss diff --git a/tune/protox/embedding/train.py b/tune/protox/embedding/train.py index 67609f56..9b3dc944 100644 --- a/tune/protox/embedding/train.py +++ b/tune/protox/embedding/train.py @@ -41,10 +41,7 @@ @click.pass_obj # generic args -@click.argument( - "benchmark-name", - type=str -) +@click.argument("benchmark-name", type=str) @click.option( "--seed-start", type=int, @@ -107,19 +104,11 @@ default=40, help=f"The # of times to specific hyperparameter configs to sample from the hyperparameter search space and train embedding models with.", ) -@click.option( - "--train-size", - type=float, - default=0.99, - help=f"TODO(wz2)" -) +@click.option("--train-size", type=float, default=0.99, help=f"TODO(wz2)") # analyze args @click.option( - "--start-epoch", - type=int, - default=0, - help="The epoch to start analyzing models at." + "--start-epoch", type=int, default=0, help="The epoch to start analyzing models at." ) @click.option( "--batch-size", @@ -178,21 +167,12 @@ help="The number of indexes whose errors to compute during _attach().", ) @click.option( - "--num-curate", - type=int, - default=1, - help="The number of models to curate" + "--num-curate", type=int, default=1, help="The number of models to curate" ) # TODO(wz2): why would we want to curate more than one? @click.option( - "--allow-all", - is_flag=True, - help="Whether to curate within or across parts." -) -@click.option("--flatten-idx", - type=int, - default=0, - help="TODO(wz2)" + "--allow-all", is_flag=True, help="Whether to curate within or across parts." 
) +@click.option("--flatten-idx", type=int, default=0, help="TODO(wz2)") def train( dbgym_cfg: DBGymConfig, benchmark_name: str, diff --git a/tune/protox/embedding/train_args.py b/tune/protox/embedding/train_args.py index 21b2917a..c86a6392 100644 --- a/tune/protox/embedding/train_args.py +++ b/tune/protox/embedding/train_args.py @@ -72,7 +72,7 @@ def __init__( idx_limit: int, num_curate: int, allow_all: bool, - flatten_idx: int + flatten_idx: int, ) -> None: self.recon = recon self.latent_dim = latent_dim diff --git a/tune/protox/env/logger.py b/tune/protox/env/logger.py index 36750f10..6cf2a4fe 100644 --- a/tune/protox/env/logger.py +++ b/tune/protox/env/logger.py @@ -25,7 +25,7 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T: ret = f(*args, **kwargs) # TODO(wz2): This is a hack to get a logger instance. - first_arg = args[0] # Ignore the indexing type error + first_arg = args[0] # Ignore the indexing type error assert hasattr(first_arg, "logger"), print(first_arg, type(first_arg)) if first_arg.logger is None: diff --git a/tune/protox/env/mqo/mqo_wrapper.py b/tune/protox/env/mqo/mqo_wrapper.py index 50b9dc39..6f39a43a 100644 --- a/tune/protox/env/mqo/mqo_wrapper.py +++ b/tune/protox/env/mqo/mqo_wrapper.py @@ -163,7 +163,7 @@ def __init__( self.logger = logger def _update_best_observed( - self, query_metric_data: dict[str, BestQueryRun], force_overwrite: bool=False + self, query_metric_data: dict[str, BestQueryRun], force_overwrite: bool = False ) -> None: if query_metric_data is not None: for qid, best_run in query_metric_data.items(): diff --git a/tune/protox/env/pg_env.py b/tune/protox/env/pg_env.py index 5f94e587..b22a9521 100644 --- a/tune/protox/env/pg_env.py +++ b/tune/protox/env/pg_env.py @@ -304,10 +304,12 @@ def step_execute( "query_metric_data": query_metric_data, "reward": reward, "results_dpath": results_dpath, - "actions_info": ActionsInfo({ - "all_holon_action_variations": all_holon_action_variations, - "best_observed_holon_action": None - }), + "actions_info": ActionsInfo( + { + "all_holon_action_variations": all_holon_action_variations, + "best_observed_holon_action": None, + } + ), } ) ) diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index d79ec8db..35d3d8a0 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -206,4 +206,4 @@ class EnvInfoDict(TypedDict, total=False): # New state container. state_container: HolonStateContainer # What the LSC associated with the action is. 
- lsc: float \ No newline at end of file + lsc: float diff --git a/tune/protox/tests/test_index_space.py b/tune/protox/tests/test_index_space.py index 977e6764..9ccfd73e 100644 --- a/tune/protox/tests/test_index_space.py +++ b/tune/protox/tests/test_index_space.py @@ -13,11 +13,11 @@ class IndexSpaceTests(unittest.TestCase): @staticmethod def load( - config_path: Path=Path( + config_path: Path = Path( "tune/protox/tests/unittest_benchmark_configs/unittest_tpch.yaml" ).resolve(), - aux_type: bool=True, - aux_include: bool=True, + aux_type: bool = True, + aux_include: bool = True, ) -> tuple[Workload, IndexSpace]: # don't call open_and_save() because this is a unittest with open(config_path, "r") as f: @@ -86,7 +86,9 @@ def test_neighborhood(self) -> None: _, isa = IndexSpaceTests.load(aux_type=False, aux_include=False) act = isa.sample(mask={"table_idx": 2, "col_idx": 1}) - act = IndexSpaceRawSample(tuple([0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)])) + act = IndexSpaceRawSample( + tuple([0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)]) + ) self.assertTrue(check_subspace(i, act)) neighbors = i.policy.structural_neighbors(act) diff --git a/tune/protox/tests/test_workload.py b/tune/protox/tests/test_workload.py index 79c6c45b..04a0f980 100644 --- a/tune/protox/tests/test_workload.py +++ b/tune/protox/tests/test_workload.py @@ -1,7 +1,7 @@ import json -from typing import Any, Tuple import unittest from pathlib import Path +from typing import Any, Tuple import yaml @@ -46,7 +46,9 @@ def load(config_file: str, workload_path: Path) -> tuple[Workload, IndexSpace]: ) return w, i - def diff_classmapping(self, ref: dict[TableColTuple, int], target: dict[TableColTuple, int]) -> None: + def diff_classmapping( + self, ref: dict[TableColTuple, int], target: dict[TableColTuple, int] + ) -> None: for k, v in ref.items(): self.assertTrue(k in target, msg=f"{k} is missing.") self.assertTrue(v == target[k]) diff --git a/tune/protox/tests/test_workload_utils.py b/tune/protox/tests/test_workload_utils.py index 3f433f37..be2fd9a8 100644 --- a/tune/protox/tests/test_workload_utils.py +++ b/tune/protox/tests/test_workload_utils.py @@ -2,8 +2,12 @@ import pglast -from tune.protox.env.types import QueryType, AttrTableListMap -from tune.protox.env.util.workload_analysis import extract_aliases, extract_sqltypes, extract_columns +from tune.protox.env.types import AttrTableListMap, QueryType +from tune.protox.env.util.workload_analysis import ( + extract_aliases, + extract_columns, + extract_sqltypes, +) class WorkloadUtilsTests(unittest.TestCase): @@ -17,69 +21,71 @@ class WorkloadUtilsTests(unittest.TestCase): "nation", "region", ] - TPCH_ALL_ATTRIBUTES = AttrTableListMap({ - "r_regionkey": ["region"], - "r_name": ["region"], - "r_comment": ["region"], - "n_nationkey": ["nation"], - "n_name": ["nation"], - "n_regionkey": ["nation"], - "n_comment": ["nation"], - "p_partkey": ["part"], - "p_name": ["part"], - "p_mfgr": ["part"], - "p_brand": ["part"], - "p_type": ["part"], - "p_size": ["part"], - "p_container": ["part"], - "p_retailprice": ["part"], - "p_comment": ["part"], - "s_suppkey": ["supplier"], - "s_name": ["supplier"], - "s_address": ["supplier"], - "s_nationkey": ["supplier"], - "s_phone": ["supplier"], - "s_acctbal": ["supplier"], - "s_comment": ["supplier"], - "ps_partkey": ["partsupp"], - "ps_suppkey": ["partsupp"], - "ps_availqty": ["partsupp"], - "ps_supplycost": ["partsupp"], - "ps_comment": ["partsupp"], - "c_custkey": ["customer"], - "c_name": ["customer"], - "c_address": 
["customer"], - "c_nationkey": ["customer"], - "c_phone": ["customer"], - "c_acctbal": ["customer"], - "c_mktsegment": ["customer"], - "c_comment": ["customer"], - "o_orderkey": ["orders"], - "o_custkey": ["orders"], - "o_orderstatus": ["orders"], - "o_totalprice": ["orders"], - "o_orderdate": ["orders"], - "o_orderpriority": ["orders"], - "o_clerk": ["orders"], - "o_shippriority": ["orders"], - "o_comment": ["orders"], - "l_orderkey": ["lineitem"], - "l_partkey": ["lineitem"], - "l_suppkey": ["lineitem"], - "l_linenumber": ["lineitem"], - "l_quantity": ["lineitem"], - "l_extendedprice": ["lineitem"], - "l_discount": ["lineitem"], - "l_tax": ["lineitem"], - "l_returnflag": ["lineitem"], - "l_linestatus": ["lineitem"], - "l_shipdate": ["lineitem"], - "l_commitdate": ["lineitem"], - "l_receiptdate": ["lineitem"], - "l_shipinstruct": ["lineitem"], - "l_shipmode": ["lineitem"], - "l_comment": ["lineitem"], - }) + TPCH_ALL_ATTRIBUTES = AttrTableListMap( + { + "r_regionkey": ["region"], + "r_name": ["region"], + "r_comment": ["region"], + "n_nationkey": ["nation"], + "n_name": ["nation"], + "n_regionkey": ["nation"], + "n_comment": ["nation"], + "p_partkey": ["part"], + "p_name": ["part"], + "p_mfgr": ["part"], + "p_brand": ["part"], + "p_type": ["part"], + "p_size": ["part"], + "p_container": ["part"], + "p_retailprice": ["part"], + "p_comment": ["part"], + "s_suppkey": ["supplier"], + "s_name": ["supplier"], + "s_address": ["supplier"], + "s_nationkey": ["supplier"], + "s_phone": ["supplier"], + "s_acctbal": ["supplier"], + "s_comment": ["supplier"], + "ps_partkey": ["partsupp"], + "ps_suppkey": ["partsupp"], + "ps_availqty": ["partsupp"], + "ps_supplycost": ["partsupp"], + "ps_comment": ["partsupp"], + "c_custkey": ["customer"], + "c_name": ["customer"], + "c_address": ["customer"], + "c_nationkey": ["customer"], + "c_phone": ["customer"], + "c_acctbal": ["customer"], + "c_mktsegment": ["customer"], + "c_comment": ["customer"], + "o_orderkey": ["orders"], + "o_custkey": ["orders"], + "o_orderstatus": ["orders"], + "o_totalprice": ["orders"], + "o_orderdate": ["orders"], + "o_orderpriority": ["orders"], + "o_clerk": ["orders"], + "o_shippriority": ["orders"], + "o_comment": ["orders"], + "l_orderkey": ["lineitem"], + "l_partkey": ["lineitem"], + "l_suppkey": ["lineitem"], + "l_linenumber": ["lineitem"], + "l_quantity": ["lineitem"], + "l_extendedprice": ["lineitem"], + "l_discount": ["lineitem"], + "l_tax": ["lineitem"], + "l_returnflag": ["lineitem"], + "l_linestatus": ["lineitem"], + "l_shipdate": ["lineitem"], + "l_commitdate": ["lineitem"], + "l_receiptdate": ["lineitem"], + "l_shipinstruct": ["lineitem"], + "l_shipmode": ["lineitem"], + "l_comment": ["lineitem"], + } + ) TPCH_Q1 = """ select l_returnflag, diff --git a/util/pg.py b/util/pg.py index 2cf4506d..31e4f0c9 100644 --- a/util/pg.py +++ b/util/pg.py @@ -41,7 +41,7 @@ def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) - # The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres # at the same time. 
In this situation, they need to have different ports -def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool=True) -> str: +def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True) -> str: connstr_suffix = f"{DBGYM_POSTGRES_USER}:{DBGYM_POSTGRES_PASS}@localhost:{pgport}/{DBGYM_POSTGRES_DBNAME}" # use_psycopg means whether or not we use the psycopg.connect() function # counterintuively, you *don't* need psycopg in the connection string if you *are* @@ -50,11 +50,13 @@ def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool=True) -> return connstr_prefix + "://" + connstr_suffix -def create_conn(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool=True) -> Connection: +def create_conn( + pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True +) -> Connection: connstr = get_connstr(use_psycopg=use_psycopg, pgport=pgport) if use_psycopg: psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) - engine = create_engine(connstr, creator=lambda : psycopg_conn) + engine = create_engine(connstr, creator=lambda: psycopg_conn) return engine.connect() else: engine = create_engine( diff --git a/util/shell.py b/util/shell.py index 29a03aff..d20097ec 100644 --- a/util/shell.py +++ b/util/shell.py @@ -1,14 +1,19 @@ import logging import os -from pathlib import Path import subprocess +from pathlib import Path from typing import Optional shell_util_logger = logging.getLogger("shell_util") shell_util_logger.setLevel(logging.INFO) -def subprocess_run(c: str, cwd: Optional[Path]=None, check_returncode: bool=True, verbose: bool=True) -> subprocess.Popen[str]: +def subprocess_run( + c: str, + cwd: Optional[Path] = None, + check_returncode: bool = True, + verbose: bool = True, +) -> subprocess.Popen[str]: cwd_msg = f"(cwd: {cwd if cwd is not None else os.getcwd()})" if verbose: From bada7335972125f632347dcabfa5ba0d5b3dbb46 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:41:18 +0000 Subject: [PATCH 49/60] added mypy to CI --- .github/workflows/tests_ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/tests_ci.yml b/.github/workflows/tests_ci.yml index 948b2a31..46295d8b 100644 --- a/.github/workflows/tests_ci.yml +++ b/.github/workflows/tests_ci.yml @@ -36,6 +36,10 @@ jobs: run: | ./scripts/check_format.sh + - name: Static type checking + run: | + mypy --config-file scripts/mypy.ini . + - name: Run unit tests run: | . 
"$HOME/.cargo/env" From e0944096f0eba91b48add59669115296e31873b1 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 21:45:55 +0000 Subject: [PATCH 50/60] small type error --- tune/protox/env/space/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tune/protox/env/space/utils.py b/tune/protox/env/space/utils.py index 734a4ca8..0977a906 100644 --- a/tune/protox/env/space/utils.py +++ b/tune/protox/env/space/utils.py @@ -224,7 +224,7 @@ def fetch_server_knobs( def fetch_server_indexes( connection: Connection[Any], tables: list[str] -) -> typing.tuple[TableAttrListMap, ServerTableIndexMetadata]: +) -> tuple[TableAttrListMap, ServerTableIndexMetadata]: rel_metadata = TableAttrListMap({t: [] for t in tables}) existing_indexes = ServerTableIndexMetadata({}) with connection.cursor(row_factory=dict_row) as cursor: From 5ec9ff66cee71ba5818cc7ea6998f00f26693d56 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 22:14:25 +0000 Subject: [PATCH 51/60] fixed issues around psycopg and sqlalchemy conn --- dbms/postgres/cli.py | 12 ++++++------ util/pg.py | 45 ++++++++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/dbms/postgres/cli.py b/dbms/postgres/cli.py index ff729b73..2d968a98 100644 --- a/dbms/postgres/cli.py +++ b/dbms/postgres/cli.py @@ -13,7 +13,7 @@ from typing import Optional import click -from sqlalchemy import Connection +import sqlalchemy from benchmark.tpch.load_info import TpchLoadInfo from dbms.load_info_base_class import LoadInfoBaseClass @@ -36,9 +36,9 @@ DEFAULT_POSTGRES_DBNAME, DEFAULT_POSTGRES_PORT, SHARED_PRELOAD_LIBRARIES, - conn_execute, - create_conn, + create_sqlalchemy_conn, sql_file_execute, + sqlalchemy_conn_execute, ) from util.shell import subprocess_run @@ -249,7 +249,7 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None: def _load_benchmark_into_dbdata( dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float ) -> None: - with create_conn(use_psycopg=False) as conn: + with create_sqlalchemy_conn() as conn: if benchmark_name == "tpch": load_info = TpchLoadInfo(dbgym_cfg, scale_factor) else: @@ -261,13 +261,13 @@ def _load_benchmark_into_dbdata( def _load_into_dbdata( - dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass + dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, load_info: LoadInfoBaseClass ) -> None: sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath()) # truncate all tables first before even loading a single one for table, _ in load_info.get_tables_and_fpaths(): - conn_execute(conn, f"TRUNCATE {table} CASCADE") + sqlalchemy_conn_execute(conn, f"TRUNCATE {table} CASCADE") # then, load the tables for table, table_fpath in load_info.get_tables_and_fpaths(): with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv: diff --git a/util/pg.py b/util/pg.py index 31e4f0c9..9d8a370e 100644 --- a/util/pg.py +++ b/util/pg.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import Any, List +from typing import Any, List, NewType, Union import pglast import psycopg -from sqlalchemy import Connection, Engine, create_engine, text -from sqlalchemy.engine import CursorResult +import sqlalchemy +from sqlalchemy import create_engine, text from misc.utils import DBGymConfig, open_and_save @@ -16,7 +16,9 @@ SHARED_PRELOAD_LIBRARIES = "boot,pg_hint_plan,pg_prewarm" -def conn_execute(conn: Connection, sql: str) -> CursorResult[Any]: +def sqlalchemy_conn_execute( + conn: sqlalchemy.Connection, sql: str +) -> 
sqlalchemy.engine.CursorResult[Any]: return conn.execute(text(sql)) @@ -34,9 +36,11 @@ def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> list[str]: return queries -def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) -> None: +def sql_file_execute( + dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, filepath: Path +) -> None: for sql in sql_file_queries(dbgym_cfg, filepath): - conn_execute(conn, sql) + sqlalchemy_conn_execute(conn, sql) # The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres @@ -50,17 +54,18 @@ def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True) - return connstr_prefix + "://" + connstr_suffix -def create_conn( - pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True -) -> Connection: - connstr = get_connstr(use_psycopg=use_psycopg, pgport=pgport) - if use_psycopg: - psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) - engine = create_engine(connstr, creator=lambda: psycopg_conn) - return engine.connect() - else: - engine = create_engine( - connstr, - execution_options={"isolation_level": "AUTOCOMMIT"}, - ) - return engine.connect() +def create_psycopg_conn(pgport: int = DEFAULT_POSTGRES_PORT) -> psycopg.Connection[Any]: + connstr = get_connstr(use_psycopg=True, pgport=pgport) + psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) + return psycopg_conn + + +def create_sqlalchemy_conn( + pgport: int = DEFAULT_POSTGRES_PORT, +) -> sqlalchemy.Connection: + connstr = get_connstr(use_psycopg=True, pgport=pgport) + engine = create_engine( + connstr, + execution_options={"isolation_level": "AUTOCOMMIT"}, + ) + return engine.connect() From 632d2a265c09a9f024c47ea0dc7286d2af6dee9e Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Mon, 2 Sep 2024 23:25:56 +0000 Subject: [PATCH 52/60] made a few fixes to select.py --- scripts/mypy.ini | 3 -- tune/protox/embedding/datagen.py | 5 ++-- tune/protox/embedding/select.py | 50 ++++++++++++++------------------ 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/scripts/mypy.ini b/scripts/mypy.ini index e4c360df..98ef8d68 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -1,6 +1,3 @@ [mypy] strict = True ignore_missing_imports = True - -[mypy-tune.protox.embedding.*] -ignore_errors = True \ No newline at end of file diff --git a/tune/protox/embedding/datagen.py b/tune/protox/embedding/datagen.py index c415d9fc..a1315d57 100644 --- a/tune/protox/embedding/datagen.py +++ b/tune/protox/embedding/datagen.py @@ -15,7 +15,7 @@ import yaml from sklearn.preprocessing import quantile_transform -from dbms.postgres.cli import create_conn, start_postgres, stop_postgres +from dbms.postgres.cli import start_postgres, stop_postgres from misc.utils import ( BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER, @@ -39,6 +39,7 @@ from tune.protox.env.space.primitive_space.index_space import IndexSpace from tune.protox.env.types import QueryType from tune.protox.env.workload import Workload +from util.pg import create_psycopg_conn from util.shell import subprocess_run # FUTURE(oltp) @@ -780,7 +781,7 @@ def _produce_index_data( # there are no indexes to generate. 
return - with create_conn() as connection: + with create_psycopg_conn() as connection: _fetch_server_indexes(connection) if generate_costs: try: diff --git a/tune/protox/embedding/select.py b/tune/protox/embedding/select.py index 1e28dce0..db23fdd7 100644 --- a/tune/protox/embedding/select.py +++ b/tune/protox/embedding/select.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +from pandas import DataFrame import tqdm from misc.utils import DBGymConfig, default_embedder_dname, link_result @@ -15,12 +16,6 @@ ) -class DotDict(dict): - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def select_best_embeddings( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, @@ -28,9 +23,7 @@ def select_best_embeddings( ) -> None: data = _load_data(dbgym_cfg, select_args) - if generic_args.traindata_path is not None and os.path.exists( - generic_args.traindata_path - ): + if generic_args.traindata_path is not None and generic_args.traindata_path.exists(): raw_data = pd.read_parquet(generic_args.traindata_path) data = _attach(data, raw_data, select_args.idx_limit) @@ -97,8 +90,8 @@ def select_best_embeddings( info_txt.close() -def _load_data(dbgym_cfg, select_args): - data = [] +def _load_data(dbgym_cfg: DBGymConfig, select_args: EmbeddingSelectArgs) -> DataFrame: + stat_infos = [] stats = [s for s in dbgym_cfg.dbgym_this_run_path.rglob(STATS_FNAME)] print(f"stats={stats}") for stat in stats: @@ -147,10 +140,9 @@ def recurse_set(source, target): info["ranges_file"] = str(Path(stat).parent / RANGES_FNAME) - data.append(info) + stat_infos.append(info) - print(f"data={data}") - data = pd.DataFrame(data) + data = DataFrame(stat_infos) data = data.loc[:, ~(data == data.iloc[0]).all()] if "output_scale" not in data: @@ -162,13 +154,13 @@ def recurse_set(source, target): return data -def _attach(data, raw_data, num_limit=0): +def _attach(data, raw_data, num_limit: int=0) -> DataFrame: # As the group index goes up, the perf should go up (i.e., bounds should tighten) filtered_data = {} new_data = [] for tup in tqdm.tqdm(data.itertuples(), total=data.shape[0]): - tup = DotDict({k: getattr(tup, k) for k in data.columns}) - if raw_data is not None and Path(tup.ranges_file).exists(): + tup_dict = {k: getattr(tup, k) for k in data.columns} + if raw_data is not None and Path(tup_dict["ranges_file"]).exists(): def compute_dist_score(current_dists, base, upper): nonlocal filtered_data @@ -202,7 +194,7 @@ def compute_dist_score(current_dists, base, upper): return error # don't use open_and_save() because we generated ranges in this run - with open(tup.ranges_file, "r") as f: + with open(tup_dict["ranges_file"], "r") as f: errors = [] drange = (None, None) current_dists = {} @@ -219,9 +211,9 @@ def compute_dist_score(current_dists, base, upper): break if drange[0] is None: - drange = (1.0 - tup.bias_separation, 1.01) + drange = (1.0 - tup_dict["bias_separation"], 1.01) else: - drange = (drange[0] - tup.bias_separation, drange[0]) + drange = (drange[0] - tup_dict["bias_separation"], drange[0]) current_dists = {} else: @@ -232,19 +224,19 @@ def compute_dist_score(current_dists, base, upper): if len(current_dists) > 0: # Put the error in. 
errors.append( - compute_dist_score(current_dists, 0.0, tup.bias_separation) + compute_dist_score(current_dists, 0.0, tup_dict["bias_separation"]) ) - tup["idx_class_errors"] = ",".join( + tup_dict["idx_class_errors"] = ",".join( [str(np.round(e, 2)) for e in errors] ) for i, e in enumerate(errors): - tup[f"idx_class_error{i}"] = np.round(e, 2) + tup_dict[f"idx_class_error{i}"] = np.round(e, 2) if len(errors) > 0: - tup["idx_class_mean_error"] = np.mean(errors) - tup["idx_class_total_error"] = np.sum(errors) - tup["idx_class_min_error"] = np.min(errors) - tup["idx_class_max_error"] = np.max(errors) - new_data.append(dict(tup)) - return pd.DataFrame(new_data) + tup_dict["idx_class_mean_error"] = np.mean(errors) + tup_dict["idx_class_total_error"] = np.sum(errors) + tup_dict["idx_class_min_error"] = np.min(errors) + tup_dict["idx_class_max_error"] = np.max(errors) + new_data.append(tup_dict) + return DataFrame(new_data) From 181c6e9ddf1e9149db3c75e40e36150946788227 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:05:13 +0000 Subject: [PATCH 53/60] fixed tune/protox/embedding/datagen.py --- tune/protox/embedding/datagen.py | 228 +++++++++++++++---------------- tune/protox/env/workload.py | 2 +- 2 files changed, 113 insertions(+), 117 deletions(-) diff --git a/tune/protox/embedding/datagen.py b/tune/protox/embedding/datagen.py index a1315d57..b5ff7897 100644 --- a/tune/protox/embedding/datagen.py +++ b/tune/protox/embedding/datagen.py @@ -8,10 +8,12 @@ from itertools import chain, combinations from multiprocessing import Pool from pathlib import Path +from typing import Any, NewType, Optional, cast import click import numpy as np import pandas as pd +import psycopg import yaml from sklearn.preprocessing import quantile_transform @@ -37,7 +39,7 @@ ) from tune.protox.embedding.loss import COST_COLUMNS from tune.protox.env.space.primitive_space.index_space import IndexSpace -from tune.protox.env.types import QueryType +from tune.protox.env.types import QuerySpec, QueryType, TableAttrAccessSetsMap, TableAttrListMap from tune.protox.env.workload import Workload from util.pg import create_psycopg_conn from util.shell import subprocess_run @@ -50,6 +52,9 @@ # pass +QueryBatches = NewType("QueryBatches", list[tuple[str, list[tuple[QueryType, str]], Any]]) + + # click steup @click.command() @click.pass_obj @@ -75,6 +80,7 @@ ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @@ -87,8 +93,8 @@ # TODO(phw2): need to run pgtune before gathering data @click.option( "--pristine-dbdata-snapshot-path", - default=None, type=Path, + default=None, help=f"The path to the .tgz snapshot of the dbdata directory to build an embedding space over. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( @@ -99,60 +105,60 @@ ) @click.option( "--dbdata-parent-dpath", - default=None, type=Path, + default=None, help=f"The path to the parent directory of the dbdata which will be actively tuned. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. 
The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--workload-path", - default=None, type=Path, + default=None, help=f"The path to the directory that specifies the workload (such as its queries and order of execution). The default is {default_workload_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). The default is a random value.", ) # dir gen args @click.option( "--leading-col-tbls", - default=None, type=str, + default=None, help='All tables included here will have indexes created s.t. each column is represented equally often as the "leading column" of the index.', ) # TODO(wz2): what if we sample tbl_sample_limit / len(cols) for tables in leading_col_tbls? this way, tbl_sample_limit will always represent the total # of indexes created on that table. currently the description of the param is a bit weird as you can see @click.option( "--default-sample-limit", - default=2048, type=int, + default=2048, help="The default sample limit of all tables, used unless override sample limit is specified. If the table is in --leading-col-tbls, sample limit is # of indexes to sample per column for that table table. If the table is in --leading-col-tbls, sample limit is the # of indexes to sample total for that table.", ) @click.option( "--override-sample-limits", - default=None, type=str, + default=None, help='Override the sample limit for specific tables. An example input would be "lineitem,32768,orders,4096".', ) # TODO(wz2): if I'm just outputting out.parquet instead of the full directory, do we even need file limit at all? @click.option( "--file-limit", - default=1024, type=int, + default=1024, help="The max # of data points (one data point = one hypothetical index) per file", ) @click.option( "--max-concurrent", - default=None, type=int, + default=None, help="The max # of concurrent threads that will be creating hypothetical indexes. The default is `nproc`.", ) # TODO(wz2): when would we not want to generate costs? 
@@ -161,33 +167,33 @@ # file gen args @click.option("--table-shape", is_flag=True, help="TODO(wz2)") @click.option("--dual-class", is_flag=True, help="TODO(wz2)") -@click.option("--pad-min", default=None, type=int, help="TODO(wz2)") -@click.option("--rebias", default=0, type=float, help="TODO(wz2)") +@click.option("--pad-min", type=int, default=None, help="TODO(wz2)") +@click.option("--rebias", type=float, default=0, help="TODO(wz2)") def datagen( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - pgbin_path, - pristine_dbdata_snapshot_path, - intended_dbdata_hardware, - dbdata_parent_dpath, - benchmark_config_path, - workload_path, - seed, - leading_col_tbls, - default_sample_limit, - override_sample_limits, - file_limit, - max_concurrent, - no_generate_costs, - table_shape, - dual_class, - pad_min, - rebias, -): + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + pgbin_path: Optional[Path], + pristine_dbdata_snapshot_path: Optional[Path], + intended_dbdata_hardware: str, + dbdata_parent_dpath: Optional[Path], + benchmark_config_path: Optional[Path], + workload_path: Optional[Path], + seed: Optional[int], + leading_col_tbls: str, + default_sample_limit: int, + override_sample_limits: Optional[str], + file_limit: int, + max_concurrent: Optional[int], + no_generate_costs: bool, + table_shape: bool, + dual_class: bool, + pad_min: int, + rebias: float, +) -> None: """ Samples the effects of indexes on the workload as estimated by HypoPG. Outputs all this data as a .parquet file in the run_*/ dir. @@ -220,8 +226,9 @@ def datagen( ) if max_concurrent is None: max_concurrent = os.cpu_count() + assert max_concurrent is not None if seed is None: - seed = random.randint(0, 1e8) + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths workload_path = conv_inputpath_to_realabspath(dbgym_cfg, workload_path) @@ -247,22 +254,19 @@ def datagen( assert False # Process the "data structure" args - leading_col_tbls = [] if leading_col_tbls is None else leading_col_tbls.split(",") + leading_col_tbls_parsed: list[str] = [] if leading_col_tbls is None else leading_col_tbls.split(",") # I chose to only use the "," delimiter in override_sample_limits_str, so the dictionary is encoded as [key],[value],[key],[value] # I felt this was better than introducing a new delimiter which might conflict with the name of a table - if override_sample_limits is None: - override_sample_limits = dict() - else: - override_sample_limits_str = override_sample_limits - override_sample_limits = dict() - override_sample_limits_str_split = override_sample_limits_str.split(",") + override_sample_limits_parsed: dict[str, int] = dict() + if override_sample_limits is not None: + override_sample_limits_str_split = override_sample_limits.split(",") assert ( len(override_sample_limits_str_split) % 2 == 0 - ), f'override_sample_limits ("{override_sample_limits_str}") does not have an even number of values' + ), f'override_sample_limits ("{override_sample_limits}") does not have an even number of values' for i in range(0, len(override_sample_limits_str_split), 2): tbl = override_sample_limits_str_split[i] limit = int(override_sample_limits_str_split[i + 1]) - override_sample_limits[tbl] = limit + override_sample_limits_parsed[tbl] = limit # Group args together to reduce the # of parameters we pass into functions # I chose to group them into separate objects instead because it felt hacky to pass a giant 
args object into every function @@ -277,9 +281,9 @@ def datagen( dbdata_parent_dpath, ) dir_gen_args = EmbeddingDirGenArgs( - leading_col_tbls, + leading_col_tbls_parsed, default_sample_limit, - override_sample_limits, + override_sample_limits_parsed, file_limit, max_concurrent, no_generate_costs, @@ -332,14 +336,14 @@ class EmbeddingDatagenGenericArgs: def __init__( self, - benchmark_name, - workload_name, - scale_factor, - benchmark_config_path, - seed, - workload_path, - pristine_dbdata_snapshot_path, - dbdata_parent_dpath, + benchmark_name: str, + workload_name: str, + scale_factor: float, + benchmark_config_path: Path, + seed: int, + workload_path: Path, + pristine_dbdata_snapshot_path: Path, + dbdata_parent_dpath: Path, ): self.benchmark_name = benchmark_name self.workload_name = workload_name @@ -356,12 +360,12 @@ class EmbeddingDirGenArgs: def __init__( self, - leading_col_tbls, - default_sample_limit, - override_sample_limits, - file_limit, - max_concurrent, - no_generate_costs, + leading_col_tbls: list[str], + default_sample_limit: int, + override_sample_limits: dict[str, int], + file_limit: int, + max_concurrent: int, + no_generate_costs: bool, ): self.leading_col_tbls = leading_col_tbls self.default_sample_limit = default_sample_limit @@ -374,25 +378,25 @@ def __init__( class EmbeddingFileGenArgs: """Same comment as EmbeddingDatagenGenericArgs""" - def __init__(self, table_shape, dual_class, pad_min, rebias): + def __init__(self, table_shape: bool, dual_class: bool, pad_min: int, rebias: float): self.table_shape = table_shape self.dual_class = dual_class self.pad_min = pad_min self.rebias = rebias -def get_traindata_dir(dbgym_cfg): +def get_traindata_dir(dbgym_cfg: DBGymConfig) -> Path: return dbgym_cfg.dbgym_this_run_path / "traindata_dir" -def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args, dir_gen_args): +def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args: EmbeddingDatagenGenericArgs, dir_gen_args: EmbeddingDirGenArgs) -> None: with open_and_save(dbgym_cfg, generic_args.benchmark_config_path, "r") as f: benchmark_config = yaml.safe_load(f) - max_num_columns = benchmark_config["protox"]["max_num_columns"] - tables = benchmark_config["protox"]["tables"] - attributes = benchmark_config["protox"]["attributes"] - query_spec = benchmark_config["protox"]["query_spec"] + max_num_columns: int = benchmark_config["protox"]["max_num_columns"] + tables: list[str] = benchmark_config["protox"]["tables"] + attributes: TableAttrListMap = benchmark_config["protox"]["attributes"] + query_spec: QuerySpec = benchmark_config["protox"]["query_spec"] workload = Workload( dbgym_cfg, tables, attributes, query_spec, generic_args.workload_path, pid=None @@ -404,11 +408,7 @@ def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args, dir_gen_args): results = [] job_id = 0 for tbl in tables: - cols = ( - [None] - if tbl not in dir_gen_args.leading_col_tbls - else modified_attrs[tbl] - ) + cols: list[Optional[str]] = [None] if tbl not in dir_gen_args.leading_col_tbls else cast(list[Optional[str]], modified_attrs[tbl]) for colidx, col in enumerate(cols): if col is None: output = traindata_dir / tbl @@ -456,7 +456,7 @@ def _combine_traindata_dir_into_parquet( dbgym_cfg: DBGymConfig, generic_args: EmbeddingDatagenGenericArgs, file_gen_args: EmbeddingFileGenArgs, -): +) -> None: tbl_dirs = {} with open_and_save(dbgym_cfg, generic_args.benchmark_config_path, "r") as f: benchmark_config = yaml.safe_load(f) @@ -562,14 +562,10 @@ def read(file: Path) -> pd.DataFrame: link_result(dbgym_cfg, 
traindata_path) -def _all_subsets(ss): - return chain(*map(lambda x: combinations(ss, x), range(0, len(ss) + 1))) - - -_INDEX_SERVER_COUNTS = {} +_INDEX_SERVER_COUNTS: dict[str, int] = {} -def _fetch_server_indexes(connection): +def _fetch_server_indexes(connection: psycopg.Connection[Any]) -> None: global _INDEX_SERVER_COUNTS query = """ SELECT t.relname as table_name, i.relname as index_name @@ -596,26 +592,26 @@ def _fetch_server_indexes(connection): # return models -def _write(data, output_dir, batch_num): +def _write(data: list[dict[str, Any]], output_dir: Path, batch_num: int) -> None: df = pd.DataFrame(data) - cols = [c for c in df if "col" in c and "str" not in c] + cols = [c for c in df.columns if "col" in c and "str" not in c] df[cols] = df[cols].astype(int) - df.to_parquet(f"{output_dir}/{batch_num}.parquet") + df.to_parquet(output_dir / f"{batch_num}.parquet") del df -def _augment_query_data(workload, data): +def _augment_query_data(workload: Workload, data: dict[str, float]) -> dict[str, float]: for qstem, value in workload.queries_mix.items(): if qstem in data: data[qstem] *= value return data -def _execute_explains(cursor, batches, models): - data = {} - ou_model_data = {} +def _execute_explains(cursor: psycopg.Cursor[Any], batches: QueryBatches, models: Optional[dict[Any, Any]]) -> dict[str, float]: + data: dict[str, float] = {} + ou_model_data: dict[str, list[Any]] = {} - def acquire_model_data(q, plan): + def acquire_model_data(q: str, plan: dict[str, Any]) -> None: nonlocal ou_model_data node_tag = plan["Node Type"] node_tag = node_tag.replace(" ", "") @@ -701,15 +697,15 @@ def acquire_model_data(q, plan): return data -def _extract_refs(generate_costs, target, cursor, workload, models): +def _extract_refs(generate_costs: bool, target: Optional[str], cursor: psycopg.Cursor[Any], workload: Workload, models: Optional[dict[Any, Any]]) -> tuple[dict[str, float], dict[str, float]]: ref_qs = {} table_ref_qs = {} if generate_costs: # Get reference costs. - batches = [ + batches = QueryBatches([ (q, workload.queries[q], workload.query_aliases[q]) for q in workload.queries.keys() - ] + ]) ref_qs = _execute_explains(cursor, batches, models) ref_qs = _augment_query_data(workload, ref_qs) @@ -718,28 +714,28 @@ def _extract_refs(generate_costs, target, cursor, workload, models): table_ref_qs = ref_qs else: qs = workload.queries_for_table(target) - batches = [(q, workload.queries[q], workload.query_aliases[q]) for q in qs] + batches = QueryBatches([(q, workload.queries[q], workload.query_aliases[q]) for q in qs]) table_ref_qs = _execute_explains(cursor, batches, models) table_ref_qs = _augment_query_data(workload, table_ref_qs) return ref_qs, table_ref_qs def _produce_index_data( - dbgym_cfg, - tables, - attributes, - query_spec, - workload_path, - max_num_columns, - seed, - generate_costs, - sample_limit, - target, - leading_col, - leading_col_name, - p, - output, -): + dbgym_cfg: DBGymConfig, + tables: list[str], + attributes: TableAttrListMap, + query_spec: QuerySpec, + workload_path: Path, + max_num_columns: int, + seed: int, + generate_costs: bool, + sample_limit: int, + target: Optional[str], + leading_col: Optional[int], + leading_col_name: Optional[str], + p: int, + output: Path, +) -> None: models = None # FUTURE(oltp) @@ -748,7 +744,7 @@ def _produce_index_data( # Construct workload. 
workload = Workload( - dbgym_cfg, tables, attributes, query_spec, workload_path, pid=str(p) + dbgym_cfg, tables, attributes, query_spec, workload_path, pid=p ) modified_attrs = workload.column_usages() @@ -764,7 +760,7 @@ def _produce_index_data( seed=seed, rel_metadata=copy.deepcopy(modified_attrs), attributes_overwrite=copy.deepcopy(modified_attrs), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), index_space_aux_type=False, index_space_aux_include=False, deterministic_policy=False, @@ -793,8 +789,7 @@ def _produce_index_data( reference_qs, table_reference_qs = _extract_refs( generate_costs, target, cursor, workload, models ) - cached_refs = {} - accum_data = [] + accum_data: list[dict[str, Any]] = [] # Repeatedly... for i in range(sample_limit): @@ -811,7 +806,7 @@ def _produce_index_data( ) ia = idxs.to_action(act) - accum = { + accum: dict[str, Any] = { "table": ia.tbl_name, } if generate_costs: @@ -848,10 +843,10 @@ def _produce_index_data( else: qs_for_tbl = workload.queries_for_table(ia.tbl_name) - batches = [ + batches = QueryBatches([ (q, workload.queries[q], workload.query_aliases[q]) for q in qs_for_tbl - ] + ]) data = _execute_explains(cursor, batches, models) data = _augment_query_data(workload, data) if models is None: @@ -890,6 +885,7 @@ def _produce_index_data( for i in range(max_num_columns): accum[f"col{i}"] = 0 + assert ia.col_idxs is not None for i, col_idx in enumerate(ia.col_idxs): accum[f"col{i}"] = col_idx + 1 diff --git a/tune/protox/env/workload.py b/tune/protox/env/workload.py index e0242d65..58d27c59 100644 --- a/tune/protox/env/workload.py +++ b/tune/protox/env/workload.py @@ -79,7 +79,7 @@ def _crunch( self.tbl_filter_queries_usage: dict[TableColTuple, set[str]] = {} # Build the SQL and table usage information. 
- self.queries_mix = {} + self.queries_mix: dict[str, float] = {} self.query_aliases = {} self.query_usages = TableAttrListMap({t: [] for t in self.tables}) tbl_include_subsets = TableAttrAccessSetsMap( From fff33e14af8396b27af54716d47ff974914a2c29 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:17:54 +0000 Subject: [PATCH 54/60] fixed tune/protox/embedding/train_all.py --- tune/protox/embedding/loss.py | 4 ++-- tune/protox/embedding/train.py | 2 +- tune/protox/embedding/train_all.py | 33 +++++++++++++++--------------- tune/protox/embedding/trainer.py | 6 +++--- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/tune/protox/embedding/loss.py b/tune/protox/embedding/loss.py index ec34473e..5fbd85c6 100644 --- a/tune/protox/embedding/loss.py +++ b/tune/protox/embedding/loss.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from pytorch_metric_learning import losses # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import losses +from pytorch_metric_learning.utils import common_functions as c_f COST_COLUMNS = [ "quant_mult_cost_improvement", diff --git a/tune/protox/embedding/train.py b/tune/protox/embedding/train.py index 9b3dc944..0f8116e8 100644 --- a/tune/protox/embedding/train.py +++ b/tune/protox/embedding/train.py @@ -187,7 +187,7 @@ def train( train_max_concurrent: int, iterations_per_epoch: int, num_samples: int, - train_size: int, + train_size: float, start_epoch: int, batch_size: int, num_batches: int, diff --git a/tune/protox/embedding/train_all.py b/tune/protox/embedding/train_all.py index 66062593..6e0ca432 100644 --- a/tune/protox/embedding/train_all.py +++ b/tune/protox/embedding/train_all.py @@ -15,6 +15,7 @@ import ray import torch import torch.nn as nn +from torch.optim import Adam # type: ignore[attr-defined] import tqdm import yaml from pytorch_metric_learning.utils import logging_presets @@ -24,7 +25,7 @@ from ray.tune.schedulers import FIFOScheduler from ray.tune.search import ConcurrencyLimiter from ray.tune.search.hyperopt import HyperOptSearch -from sklearn.model_selection import train_test_split # type: ignore +from sklearn.model_selection import train_test_split from torch.utils.data import TensorDataset from typing_extensions import ParamSpec @@ -90,7 +91,7 @@ def fetch_index_parameters( def load_input_data( dbgym_cfg: DBGymConfig, traindata_path: Path, - train_size: int, + train_size: float, max_attrs: int, require_cost: bool, seed: int, @@ -115,7 +116,7 @@ def load_input_data( gc.collect() gc.collect() - if train_size == 1: + if train_size == 1.0: train_dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y)) del x gc.collect() @@ -126,7 +127,7 @@ def load_input_data( train_x, val_x, train_y, val_y = train_test_split( x, y, - test_size=1 - train_size, + test_size=1.0 - train_size, train_size=train_size, random_state=seed, shuffle=True, @@ -161,7 +162,7 @@ def create_vae_model( "sigmoid": nn.Sigmoid, }[config["mean_output_act"]] - torch.set_float32_matmul_precision("high") # type: ignore + torch.set_float32_matmul_precision("high") model = VAE( max_categorical=max_cat_features, input_dim=cat_input, @@ -182,7 +183,7 @@ def train_all_embeddings( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, train_all_args: EmbeddingTrainAllArgs, -): +) -> None: """ Trains all num_samples models using different samples of the hyperparameter space, writing their results to different embedding_*/ folders in the run_*/ folder @@ -226,7 +227,7 @@ def 
train_all_embeddings( sync_config=SyncConfig(), verbose=2, log_to_file=True, - storage_path=dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True), + storage_path=str(dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True)), ) resources = {"cpu": 1} @@ -270,7 +271,7 @@ def _hpo_train( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, train_all_args: EmbeddingTrainAllArgs, -): +) -> None: sys.path.append(os.fspath(dbgym_cfg.dbgym_repo_path)) # Explicitly set the number of torch threads. @@ -352,11 +353,11 @@ def _build_trainer( traindata_path: Path, trial_dpath: Path, benchmark_config_path: Path, - train_size: int, + train_size: float, workload_path: Path, - dataloader_num_workers=0, - disable_tqdm=False, -): + dataloader_num_workers: int=0, + disable_tqdm: bool=False, +) -> tuple[VAETrainer, Callable[..., Optional[dict[str, Any]]]]: max_cat_features = 0 max_attrs = 0 @@ -401,7 +402,7 @@ def _build_trainer( models = {"trunk": trunk, "embedder": model} optimizers = { - "embedder_optimizer": torch.optim.Adam( + "embedder_optimizer": Adam( model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"] ), } @@ -442,7 +443,7 @@ def _build_trainer( def clip_grad() -> None: if config["grad_clip_amount"] is not None: - torch.nn.utils.clip_grad_norm_( # type: ignore + torch.nn.utils.clip_grad_norm_( model.parameters(), config["grad_clip_amount"] ) @@ -513,9 +514,9 @@ def epoch_end(*args: P.args, **kwargs: P.kwargs) -> Optional[dict[str, Any]]: trainer.switch_eval() pbar = None if suppress else tqdm.tqdm(total=len(val_dl)) - for i, curr_batch in enumerate(val_dl): # type: ignore + for i, curr_batch in enumerate(val_dl): # Get the losses. - trainer.calculate_loss(curr_batch) # type: ignore + trainer.calculate_loss(curr_batch) if isinstance(trainer.losses["metric_loss"], torch.Tensor): total_metric_loss.append(trainer.losses["metric_loss"].item()) else: diff --git a/tune/protox/embedding/trainer.py b/tune/protox/embedding/trainer.py index e259f9c9..6b85fcba 100644 --- a/tune/protox/embedding/trainer.py +++ b/tune/protox/embedding/trainer.py @@ -6,8 +6,8 @@ import torch import tqdm from numpy.typing import NDArray -from pytorch_metric_learning import trainers # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import trainers +from pytorch_metric_learning.utils import common_functions as c_f from torch.utils.data import Sampler @@ -170,7 +170,7 @@ def train(self, start_epoch: int = 1, num_epochs: int = 1) -> None: if not self.disable_tqdm: pbar = tqdm.tqdm(range(self.iterations_per_epoch)) else: - pbar = range(self.iterations_per_epoch) # type: ignore + pbar = range(self.iterations_per_epoch) for self.iteration in pbar: self.forward_and_backward() From 64189a5b610e87022dcd5ab81a040243181769c2 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:27:43 +0000 Subject: [PATCH 55/60] fixed tune/protox/embedding/select.py --- tune/protox/embedding/select.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tune/protox/embedding/select.py b/tune/protox/embedding/select.py index db23fdd7..da9c195d 100644 --- a/tune/protox/embedding/select.py +++ b/tune/protox/embedding/select.py @@ -2,6 +2,7 @@ import os import shutil from pathlib import Path +from typing import Any, Optional import numpy as np import pandas as pd @@ -48,6 +49,8 @@ def select_best_embeddings( if select_args.flatten_idx == -1: for tup in df.itertuples(): + assert 
type(tup.path) is str + assert type(tup.root) is str shutil.copytree( tup.path, curated_dpath / tup.path, @@ -62,6 +65,8 @@ def select_best_embeddings( info_txt = open(curated_dpath / "info.txt", "w") for loop_i, tup in enumerate(df.itertuples()): + assert type(tup.path) is str + assert type(tup.root) is str epoch = int(str(tup.path).split("epoch")[-1]) model_dpath = curated_dpath / f"model{idx}" shutil.copytree(tup.path, model_dpath) @@ -119,7 +124,7 @@ def _load_data(dbgym_cfg: DBGymConfig, select_args: EmbeddingSelectArgs) -> Data with open(stat.parent.parent.parent / "config", "r") as f: config = json.load(f) - def recurse_set(source, target): + def recurse_set(source: dict[Any, Any], target: dict[Any, Any]) -> None: for k, v in source.items(): if isinstance(v, dict): recurse_set(v, target) @@ -154,15 +159,15 @@ def recurse_set(source, target): return data -def _attach(data, raw_data, num_limit: int=0) -> DataFrame: +def _attach(data: DataFrame, raw_data: DataFrame, num_limit: int=0) -> DataFrame: # As the group index goes up, the perf should go up (i.e., bounds should tighten) - filtered_data = {} + filtered_data: dict[tuple[float, float], DataFrame] = {} new_data = [] for tup in tqdm.tqdm(data.itertuples(), total=data.shape[0]): tup_dict = {k: getattr(tup, k) for k in data.columns} if raw_data is not None and Path(tup_dict["ranges_file"]).exists(): - def compute_dist_score(current_dists, base, upper): + def compute_dist_score(current_dists: dict[str, float], base: float, upper: float) -> float: nonlocal filtered_data key = (base, upper) if key not in filtered_data: @@ -195,14 +200,15 @@ def compute_dist_score(current_dists, base, upper): # don't use open_and_save() because we generated ranges in this run with open(tup_dict["ranges_file"], "r") as f: - errors = [] - drange = (None, None) - current_dists = {} + errors: list[float] = [] + drange: tuple[Optional[float], Optional[float]] = (None, None) + current_dists: dict[str, float] = {} for line in f: if "Generating range" in line: if len(current_dists) > 0: assert drange[0] is not None + assert drange[1] is not None errors.append( compute_dist_score(current_dists, drange[0], drange[1]) ) From 31d4330b7e315ab1c25ed0bf140ec225625cdf75 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:35:08 +0000 Subject: [PATCH 56/60] fixed tune/protox/embedding/analyze.py --- tune/protox/embedding/analyze.py | 38 +++++++++---------- .../space/latent_space/latent_index_space.py | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tune/protox/embedding/analyze.py b/tune/protox/embedding/analyze.py index d895e9ca..8a1cc44d 100644 --- a/tune/protox/embedding/analyze.py +++ b/tune/protox/embedding/analyze.py @@ -7,6 +7,7 @@ import shutil import time from pathlib import Path +from typing import Any, Optional import numpy as np import torch @@ -28,20 +29,21 @@ from tune.protox.embedding.trainer import StratifiedRandomSampler from tune.protox.embedding.vae import VAELoss, gen_vae_collate from tune.protox.env.space.latent_space.latent_index_space import LatentIndexSpace +from tune.protox.env.types import ProtoAction, TableAttrAccessSetsMap from tune.protox.env.workload import Workload STATS_FNAME = "stats.txt" RANGES_FNAME = "ranges.txt" -def compute_num_parts(num_samples: int): +def compute_num_parts(num_samples: int) -> int: # TODO(phw2): in the future, implement running different parts in parallel, set OMP_NUM_THREADS accordingly, and investigate the effect of having more parts # TODO(phw2): if having more parts is 
effective, figure out a good way to specify num_parts (can it be determined automatically or should it be a CLI arg?) # TODO(phw2): does anything bad happen if num_parts doesn't evenly divide num_samples? return 1 -def redist_trained_models(dbgym_cfg: DBGymConfig, num_parts: int): +def redist_trained_models(dbgym_cfg: DBGymConfig, num_parts: int) -> None: """ Redistribute all embeddings_*/ folders inside the run_*/ folder into num_parts subfolders """ @@ -64,7 +66,7 @@ def analyze_all_embeddings_parts( num_parts: int, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Analyze all part*/ dirs _in parallel_ """ @@ -83,7 +85,7 @@ def _analyze_embeddings_part( part_i: int, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Analyze (meaning create both stats.txt and ranges.txt) all the embedding models in the part[part_i]/ dir """ @@ -107,7 +109,7 @@ def _create_stats_for_part( part_dpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Creates a stats.txt file inside each embeddings_*/models/epoch*/ dir inside this part*/ dir TODO(wz2): what does stats.txt contain? @@ -124,9 +126,7 @@ def _create_stats_for_part( ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - models = itertools.chain(*[part_dpath.rglob("config")]) - models = [m for m in models] - print(f"models={models}") + models = [m for m in itertools.chain(*[part_dpath.rglob("config")])] for model_config in tqdm.tqdm(models): if ((Path(model_config).parent) / "FAILED").exists(): print("Detected failure in: ", model_config) @@ -192,7 +192,7 @@ def _create_stats_for_part( vae_loss = VAELoss(config["loss_fn"], max_attrs, max_cat_features) # Construct the accumulator. - accumulated_stats = {} + accumulated_stats: dict[str, list[Any]] = {} for class_idx in class_mapping: accumulated_stats[f"recon_{class_idx}"] = [] @@ -320,7 +320,7 @@ def _create_ranges_for_part( part_dpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Create the ranges.txt for all models in part_dpath TODO(wz2): what does ranges.txt contain? 
@@ -341,7 +341,7 @@ def _create_ranges_for_embedder( embedder_fpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Create the ranges.txt file corresponding to a specific part*/embeddings_*/models/epoch*/embedder_*.pth file """ @@ -376,9 +376,9 @@ def _create_ranges_for_embedder( lambda x: torch.nn.Sigmoid()(x) * config["output_scale"] ) - def index_noise_scale(x, n): + def index_noise_scale(x: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: assert n is None - return torch.clamp(x, 0.0, config["output_scale"]) + return ProtoAction(torch.clamp(x, 0.0, config["output_scale"])) max_attrs, max_cat_features = fetch_vae_parameters_from_workload( workload, len(tables) @@ -392,10 +392,10 @@ def index_noise_scale(x, n): tables=tables, max_num_columns=max_num_columns, max_indexable_attributes=workload.max_indexable(), - seed=np.random.randint(1, 1e10), + seed=np.random.randint(1, int(1e10)), rel_metadata=copy.deepcopy(modified_attrs), attributes_overwrite=copy.deepcopy(modified_attrs), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), vae=vae, index_space_aux_type=False, index_space_aux_include=False, @@ -418,7 +418,7 @@ def index_noise_scale(x, n): ranges_fpath = epoch_dpath / RANGES_FNAME with open(ranges_fpath, "w") as f: for _ in tqdm.tqdm(range(num_segments), total=num_segments, leave=False): - classes = {} + classes: dict[str, int] = {} with torch.no_grad(): points = ( torch.rand(analyze_args.num_points_to_sample, config["latent_dim"]) @@ -444,18 +444,18 @@ def index_noise_scale(x, n): if idx_class not in classes: classes[idx_class] = 0 classes[idx_class] += 1 - classes = sorted( + sorted_classes = sorted( [(k, v) for k, v in classes.items()], key=lambda x: x[1], reverse=True ) if analyze_args.num_classes_to_keep != 0: - classes = classes[: analyze_args.num_classes_to_keep] + sorted_classes = sorted_classes[: analyze_args.num_classes_to_keep] f.write(f"Generating range {base} - {base + output_scale}\n") f.write( "\n".join( [ f"{k}: {v / analyze_args.num_points_to_sample}" - for (k, v) in classes + for (k, v) in sorted_classes ] ) ) diff --git a/tune/protox/env/space/latent_space/latent_index_space.py b/tune/protox/env/space/latent_space/latent_index_space.py index f92c98f7..9afa38b4 100644 --- a/tune/protox/env/space/latent_space/latent_index_space.py +++ b/tune/protox/env/space/latent_space/latent_index_space.py @@ -39,7 +39,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, ) -> None: From 10cbfe100c6556b233ddd71fd64a5bdd6c4e00ef Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:36:45 +0000 Subject: [PATCH 57/60] fixed other mypy bugs --- tune/protox/agent/build_trial.py | 5 +++-- tune/protox/env/space/latent_space/lsc_index_space.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index e46129df..7c84366f 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -126,8 +126,9 @@ def _modify_benchbase_config( def _gen_noise_scale( vae_config: dict[str, Any], hpo_params: dict[str, Any] -) -> Callable[[ProtoAction, torch.Tensor], ProtoAction]: - def f(p: ProtoAction, n: torch.Tensor) -> ProtoAction: +) -> 
Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]: + def f(p: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: + assert n is not None if hpo_params["scale_noise_perturb"]: return ProtoAction( torch.clamp( diff --git a/tune/protox/env/space/latent_space/lsc_index_space.py b/tune/protox/env/space/latent_space/lsc_index_space.py index e1425081..87290dcf 100644 --- a/tune/protox/env/space/latent_space/lsc_index_space.py +++ b/tune/protox/env/space/latent_space/lsc_index_space.py @@ -35,7 +35,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, lsc: Optional[LSC] = None, From c509893e67336d9ca62835f8cc40025d6c980385 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:37:05 +0000 Subject: [PATCH 58/60] format --- tune/protox/embedding/datagen.py | 71 +++++++++++++++++++++--------- tune/protox/embedding/select.py | 17 ++++--- tune/protox/embedding/train_all.py | 10 +++-- 3 files changed, 69 insertions(+), 29 deletions(-) diff --git a/tune/protox/embedding/datagen.py b/tune/protox/embedding/datagen.py index b5ff7897..aa75d280 100644 --- a/tune/protox/embedding/datagen.py +++ b/tune/protox/embedding/datagen.py @@ -39,7 +39,12 @@ ) from tune.protox.embedding.loss import COST_COLUMNS from tune.protox.env.space.primitive_space.index_space import IndexSpace -from tune.protox.env.types import QuerySpec, QueryType, TableAttrAccessSetsMap, TableAttrListMap +from tune.protox.env.types import ( + QuerySpec, + QueryType, + TableAttrAccessSetsMap, + TableAttrListMap, +) from tune.protox.env.workload import Workload from util.pg import create_psycopg_conn from util.shell import subprocess_run @@ -52,7 +57,9 @@ # pass -QueryBatches = NewType("QueryBatches", list[tuple[str, list[tuple[QueryType, str]], Any]]) +QueryBatches = NewType( + "QueryBatches", list[tuple[str, list[tuple[QueryType, str]], Any]] +) # click steup @@ -254,7 +261,9 @@ def datagen( assert False # Process the "data structure" args - leading_col_tbls_parsed: list[str] = [] if leading_col_tbls is None else leading_col_tbls.split(",") + leading_col_tbls_parsed: list[str] = ( + [] if leading_col_tbls is None else leading_col_tbls.split(",") + ) # I chose to only use the "," delimiter in override_sample_limits_str, so the dictionary is encoded as [key],[value],[key],[value] # I felt this was better than introducing a new delimiter which might conflict with the name of a table override_sample_limits_parsed: dict[str, int] = dict() @@ -378,7 +387,9 @@ def __init__( class EmbeddingFileGenArgs: """Same comment as EmbeddingDatagenGenericArgs""" - def __init__(self, table_shape: bool, dual_class: bool, pad_min: int, rebias: float): + def __init__( + self, table_shape: bool, dual_class: bool, pad_min: int, rebias: float + ): self.table_shape = table_shape self.dual_class = dual_class self.pad_min = pad_min @@ -389,7 +400,11 @@ def get_traindata_dir(dbgym_cfg: DBGymConfig) -> Path: return dbgym_cfg.dbgym_this_run_path / "traindata_dir" -def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args: EmbeddingDatagenGenericArgs, dir_gen_args: EmbeddingDirGenArgs) -> None: +def _gen_traindata_dir( + dbgym_cfg: DBGymConfig, + generic_args: EmbeddingDatagenGenericArgs, + dir_gen_args: EmbeddingDirGenArgs, +) -> None: with open_and_save(dbgym_cfg, 
generic_args.benchmark_config_path, "r") as f: benchmark_config = yaml.safe_load(f) @@ -408,7 +423,11 @@ def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args: EmbeddingDatagenGen results = [] job_id = 0 for tbl in tables: - cols: list[Optional[str]] = [None] if tbl not in dir_gen_args.leading_col_tbls else cast(list[Optional[str]], modified_attrs[tbl]) + cols: list[Optional[str]] = ( + [None] + if tbl not in dir_gen_args.leading_col_tbls + else cast(list[Optional[str]], modified_attrs[tbl]) + ) for colidx, col in enumerate(cols): if col is None: output = traindata_dir / tbl @@ -607,7 +626,9 @@ def _augment_query_data(workload: Workload, data: dict[str, float]) -> dict[str, return data -def _execute_explains(cursor: psycopg.Cursor[Any], batches: QueryBatches, models: Optional[dict[Any, Any]]) -> dict[str, float]: +def _execute_explains( + cursor: psycopg.Cursor[Any], batches: QueryBatches, models: Optional[dict[Any, Any]] +) -> dict[str, float]: data: dict[str, float] = {} ou_model_data: dict[str, list[Any]] = {} @@ -697,15 +718,23 @@ def acquire_model_data(q: str, plan: dict[str, Any]) -> None: return data -def _extract_refs(generate_costs: bool, target: Optional[str], cursor: psycopg.Cursor[Any], workload: Workload, models: Optional[dict[Any, Any]]) -> tuple[dict[str, float], dict[str, float]]: +def _extract_refs( + generate_costs: bool, + target: Optional[str], + cursor: psycopg.Cursor[Any], + workload: Workload, + models: Optional[dict[Any, Any]], +) -> tuple[dict[str, float], dict[str, float]]: ref_qs = {} table_ref_qs = {} if generate_costs: # Get reference costs. - batches = QueryBatches([ - (q, workload.queries[q], workload.query_aliases[q]) - for q in workload.queries.keys() - ]) + batches = QueryBatches( + [ + (q, workload.queries[q], workload.query_aliases[q]) + for q in workload.queries.keys() + ] + ) ref_qs = _execute_explains(cursor, batches, models) ref_qs = _augment_query_data(workload, ref_qs) @@ -714,7 +743,9 @@ def _extract_refs(generate_costs: bool, target: Optional[str], cursor: psycopg.C table_ref_qs = ref_qs else: qs = workload.queries_for_table(target) - batches = QueryBatches([(q, workload.queries[q], workload.query_aliases[q]) for q in qs]) + batches = QueryBatches( + [(q, workload.queries[q], workload.query_aliases[q]) for q in qs] + ) table_ref_qs = _execute_explains(cursor, batches, models) table_ref_qs = _augment_query_data(workload, table_ref_qs) return ref_qs, table_ref_qs @@ -743,9 +774,7 @@ def _produce_index_data( # models = load_ou_models(model_dir) # Construct workload. 
- workload = Workload( - dbgym_cfg, tables, attributes, query_spec, workload_path, pid=p - ) + workload = Workload(dbgym_cfg, tables, attributes, query_spec, workload_path, pid=p) modified_attrs = workload.column_usages() np.random.seed(seed) @@ -843,10 +872,12 @@ def _produce_index_data( else: qs_for_tbl = workload.queries_for_table(ia.tbl_name) - batches = QueryBatches([ - (q, workload.queries[q], workload.query_aliases[q]) - for q in qs_for_tbl - ]) + batches = QueryBatches( + [ + (q, workload.queries[q], workload.query_aliases[q]) + for q in qs_for_tbl + ] + ) data = _execute_explains(cursor, batches, models) data = _augment_query_data(workload, data) if models is None: diff --git a/tune/protox/embedding/select.py b/tune/protox/embedding/select.py index da9c195d..613730b6 100644 --- a/tune/protox/embedding/select.py +++ b/tune/protox/embedding/select.py @@ -6,8 +6,8 @@ import numpy as np import pandas as pd -from pandas import DataFrame import tqdm +from pandas import DataFrame from misc.utils import DBGymConfig, default_embedder_dname, link_result from tune.protox.embedding.analyze import RANGES_FNAME, STATS_FNAME @@ -159,7 +159,7 @@ def recurse_set(source: dict[Any, Any], target: dict[Any, Any]) -> None: return data -def _attach(data: DataFrame, raw_data: DataFrame, num_limit: int=0) -> DataFrame: +def _attach(data: DataFrame, raw_data: DataFrame, num_limit: int = 0) -> DataFrame: # As the group index goes up, the perf should go up (i.e., bounds should tighten) filtered_data: dict[tuple[float, float], DataFrame] = {} new_data = [] @@ -167,7 +167,9 @@ def _attach(data: DataFrame, raw_data: DataFrame, num_limit: int=0) -> DataFrame tup_dict = {k: getattr(tup, k) for k in data.columns} if raw_data is not None and Path(tup_dict["ranges_file"]).exists(): - def compute_dist_score(current_dists: dict[str, float], base: float, upper: float) -> float: + def compute_dist_score( + current_dists: dict[str, float], base: float, upper: float + ) -> float: nonlocal filtered_data key = (base, upper) if key not in filtered_data: @@ -219,7 +221,10 @@ def compute_dist_score(current_dists: dict[str, float], base: float, upper: floa if drange[0] is None: drange = (1.0 - tup_dict["bias_separation"], 1.01) else: - drange = (drange[0] - tup_dict["bias_separation"], drange[0]) + drange = ( + drange[0] - tup_dict["bias_separation"], + drange[0], + ) current_dists = {} else: @@ -230,7 +235,9 @@ def compute_dist_score(current_dists: dict[str, float], base: float, upper: floa if len(current_dists) > 0: # Put the error in. 
errors.append( - compute_dist_score(current_dists, 0.0, tup_dict["bias_separation"]) + compute_dist_score( + current_dists, 0.0, tup_dict["bias_separation"] + ) ) tup_dict["idx_class_errors"] = ",".join( diff --git a/tune/protox/embedding/train_all.py b/tune/protox/embedding/train_all.py index 6e0ca432..9f0aed3a 100644 --- a/tune/protox/embedding/train_all.py +++ b/tune/protox/embedding/train_all.py @@ -15,7 +15,6 @@ import ray import torch import torch.nn as nn -from torch.optim import Adam # type: ignore[attr-defined] import tqdm import yaml from pytorch_metric_learning.utils import logging_presets @@ -26,6 +25,7 @@ from ray.tune.search import ConcurrencyLimiter from ray.tune.search.hyperopt import HyperOptSearch from sklearn.model_selection import train_test_split +from torch.optim import Adam # type: ignore[attr-defined] from torch.utils.data import TensorDataset from typing_extensions import ParamSpec @@ -227,7 +227,9 @@ def train_all_embeddings( sync_config=SyncConfig(), verbose=2, log_to_file=True, - storage_path=str(dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True)), + storage_path=str( + dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True) + ), ) resources = {"cpu": 1} @@ -355,8 +357,8 @@ def _build_trainer( benchmark_config_path: Path, train_size: float, workload_path: Path, - dataloader_num_workers: int=0, - disable_tqdm: bool=False, + dataloader_num_workers: int = 0, + disable_tqdm: bool = False, ) -> tuple[VAETrainer, Callable[..., Optional[dict[str, Any]]]]: max_cat_features = 0 max_attrs = 0 From ba9f27c76440409deda261f74284be6c8c22695f Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 14:53:23 +0000 Subject: [PATCH 59/60] fixed create_sqlalchemy_conn using the psycopg connstr --- util/pg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/util/pg.py b/util/pg.py index 9d8a370e..5358b937 100644 --- a/util/pg.py +++ b/util/pg.py @@ -63,8 +63,8 @@ def create_psycopg_conn(pgport: int = DEFAULT_POSTGRES_PORT) -> psycopg.Connecti def create_sqlalchemy_conn( pgport: int = DEFAULT_POSTGRES_PORT, ) -> sqlalchemy.Connection: - connstr = get_connstr(use_psycopg=True, pgport=pgport) - engine = create_engine( + connstr = get_connstr(use_psycopg=False, pgport=pgport) + engine: sqlalchemy.Engine = create_engine( connstr, execution_options={"isolation_level": "AUTOCOMMIT"}, ) From 67d2cf5e7453f6c380622dde76c91d2caf798c95 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 17:31:29 +0000 Subject: [PATCH 60/60] step_post_execute() now returns an Optional[float] for reward --- tune/protox/env/mqo/mqo_wrapper.py | 5 ++++- tune/protox/env/pg_env.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tune/protox/env/mqo/mqo_wrapper.py b/tune/protox/env/mqo/mqo_wrapper.py index 6f39a43a..965f7952 100644 --- a/tune/protox/env/mqo/mqo_wrapper.py +++ b/tune/protox/env/mqo/mqo_wrapper.py @@ -329,7 +329,10 @@ def transmute( [best_observed_holon_action] ) - return self.unwrapped.step_post_execute(success, action, info) + obs, reward, term, trunc, info = self.step_post_execute(success, action, info) + # Since we called step_post_execute() with soft=False, we expect infos[1] (reward) to not be None. 
+ assert reward is not None + return (obs, reward, term, trunc, info) def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, EnvInfoDict]: # type: ignore assert isinstance(self.unwrapped, PostgresEnv) diff --git a/tune/protox/env/pg_env.py b/tune/protox/env/pg_env.py index b22a9521..de298170 100644 --- a/tune/protox/env/pg_env.py +++ b/tune/protox/env/pg_env.py @@ -321,8 +321,17 @@ def step_post_execute( success: bool, action: HolonAction, info: EnvInfoDict, + # If "soft" is true, it means we're calling step_post_execute() from reset(). If it's false, it means we're calling step_post_execute() from step(). soft: bool = False, - ) -> tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, Optional[float], bool, bool, EnvInfoDict]: + # If we're calling step_post_execute() from reset(), we expect info["metric"] and info["reward"] to be None. + if not soft: + assert info["reward"] is not None + assert info["metric"] is not None + else: + assert info["reward"] is None + assert info["metric"] is None + if self.workload.oltp_workload and self.horizon > 1: # If horizon = 1, then we're going to reset anyways. So easier to just untar the original archive. # Restore the crisp and clean snapshot. @@ -357,7 +366,6 @@ def step_post_execute( if not soft: self.current_step = self.current_step + 1 self.current_state = next_state - assert info["reward"] is not None return ( self.current_state, info["reward"], @@ -372,7 +380,10 @@ def step( # type: ignore assert self.tuning_mode != TuningMode.REPLAY success, info = self.step_before_execution(action) success, info = self.step_execute(success, [("PerQuery", action)], info) - return self.step_post_execute(success, action, info) + obs, reward, term, trunc, info = self.step_post_execute(success, action, info) + # Since we called step_post_execute() with soft=False, we expect infos[1] (reward) to not be None. + assert reward is not None + return (obs, reward, term, trunc, info) @time_record("shift_state") def shift_state(