Static type checking #39

Merged
merged 60 commits into from
Sep 5, 2024
Commits
eae129d
added __init__.py files to all folders
wangpatrick57 Sep 1, 2024
5eaf4bb
fixed some simple type errors
wangpatrick57 Sep 1, 2024
5b94622
fixed some more errors
wangpatrick57 Sep 1, 2024
3500ba8
fixed tune/protox/env/logger.py
wangpatrick57 Sep 1, 2024
4bf44ad
fixed task.py
wangpatrick57 Sep 1, 2024
36eb67a
fixed tune/cli.py
wangpatrick57 Sep 1, 2024
2c32b4f
fixed tune/protox/cli.py
wangpatrick57 Sep 1, 2024
06e1e40
fixed tune/protox/agent/cli.py
wangpatrick57 Sep 1, 2024
3081e75
fixed tune/protox/agent/tune.py
wangpatrick57 Sep 1, 2024
5b5d50b
fixed tune/protox/agent/hpo.py
wangpatrick57 Sep 1, 2024
6a0468a
fixed tune/protox/agent/replay.py
wangpatrick57 Sep 2, 2024
9318b8c
added mypy to req.txt
wangpatrick57 Sep 2, 2024
955216b
fixed a few scattered errors
wangpatrick57 Sep 2, 2024
e054e82
fixed tune/protox/embedding/train_args.py
wangpatrick57 Sep 2, 2024
0c81612
fixed scripts/read_parquet.py
wangpatrick57 Sep 2, 2024
8fc0af5
fixed util/shell.py
wangpatrick57 Sep 2, 2024
b626bb4
fixed tune/protox/env/space/primitive/index.py
wangpatrick57 Sep 2, 2024
b0fbba7
fixed tune/protox/embedding/vae.py
wangpatrick57 Sep 2, 2024
0eeccc5
fixed misc/utils.py
wangpatrick57 Sep 2, 2024
921ac5d
fixed tune/protox/agent/hpo.py
wangpatrick57 Sep 2, 2024
a6f2240
fixed tune/protox/agent/replay.py
wangpatrick57 Sep 2, 2024
64ddbd9
fixed tune/protox/agent/build_trial.py
wangpatrick57 Sep 2, 2024
d9d6996
fixed tune/protox/embedding/cli.py
wangpatrick57 Sep 2, 2024
70c2d19
fixed tune/protox/embedding/train.py
wangpatrick57 Sep 2, 2024
5f52029
now ignoring errors in embedding/
wangpatrick57 Sep 2, 2024
6228088
fixed util/pg.py
wangpatrick57 Sep 2, 2024
96338bc
fixed tune/protox/env/mqo/mqo_wrapper.py
wangpatrick57 Sep 2, 2024
6996e2c
fixed tune/protox/env/pg_env.py
wangpatrick57 Sep 2, 2024
20d83bf
fixed tune/protox/agent/replay.py
wangpatrick57 Sep 2, 2024
8c59fde
fixed tune/protox/tests/test_index_space.py
wangpatrick57 Sep 2, 2024
4d1d118
fixed tune/protox/tests/test_workload.py
wangpatrick57 Sep 2, 2024
083b09e
fixed tune/protox/agent/wolp/policies.py
wangpatrick57 Sep 2, 2024
45d93af
fixed tune/protox/env/space/state/structure.py
wangpatrick57 Sep 2, 2024
3eae78f
fixed tune/protox/env/workload.py
wangpatrick57 Sep 2, 2024
3303148
fixed tune/protox/env/space/holon_space.py
wangpatrick57 Sep 2, 2024
04aae88
fixed tune/protox/agent/off_policy_algorithm.py
wangpatrick57 Sep 2, 2024
bb1043f
fixed tune/protox/env/util/pg_conn.py
wangpatrick57 Sep 2, 2024
8d7d109
fixed tune/protox/env/logger.py
wangpatrick57 Sep 2, 2024
cb2d068
fixed tune/protox/tests/test_workload_utils.py
wangpatrick57 Sep 2, 2024
47c762e
fixed tune/protox/env/util/workload_analysis.py
wangpatrick57 Sep 2, 2024
2f68f7d
fixed tune/protox/tests/test_primitive.py
wangpatrick57 Sep 2, 2024
09bdf39
fixed benchmark/tpch/cli.py
wangpatrick57 Sep 2, 2024
c80f814
fixed dbms/postgres/cli.py
wangpatrick57 Sep 2, 2024
e39c151
fixed manage/tests/test_clean.py
wangpatrick57 Sep 2, 2024
b93d3f4
fixed benchmark/tpch/load_info.py
wangpatrick57 Sep 2, 2024
fcf6894
fixed manage/cli.py
wangpatrick57 Sep 2, 2024
4273cd8
replaced List/Dict/Set with list/dict/set
wangpatrick57 Sep 2, 2024
e62f612
reformatted
wangpatrick57 Sep 2, 2024
bada733
added mypy to CI
wangpatrick57 Sep 2, 2024
e094409
small type error
wangpatrick57 Sep 2, 2024
5ec9ff6
fixed issues around psycopg and sqlalchemy conn
wangpatrick57 Sep 2, 2024
632d2a2
made a few fixes to select.py
wangpatrick57 Sep 2, 2024
181c6e9
fixed tune/protox/embedding/datagen.py
wangpatrick57 Sep 3, 2024
fff33e1
fixed tune/protox/embedding/train_all.py
wangpatrick57 Sep 3, 2024
64189a5
fixed tune/protox/embedding/select.py
wangpatrick57 Sep 3, 2024
31d4330
fixed tune/protox/embedding/analyze.py
wangpatrick57 Sep 3, 2024
10cbfe1
fixed other mypy bugs
wangpatrick57 Sep 3, 2024
c509893
format
wangpatrick57 Sep 3, 2024
ba9f27c
fixed create_sqlalchemy_conn using the psycopg connstr
wangpatrick57 Sep 3, 2024
67d2cf5
step_post_execute() now returns an Optional[float] for reward
wangpatrick57 Sep 3, 2024
4 changes: 4 additions & 0 deletions .github/workflows/tests_ci.yml
@@ -36,6 +36,10 @@ jobs:
run: |
./scripts/check_format.sh

- name: Static type checking
run: |
mypy --config-file scripts/mypy.ini .

- name: Run unit tests
run: |
. "$HOME/.cargo/env"
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
__pycache__/
.mypy_cache/
.conda/
.idea/
test_clean_scratchspace/
Empty file added benchmark/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion benchmark/cli.py
@@ -6,7 +6,7 @@

@click.group(name="benchmark")
@click.pass_obj
def benchmark_group(dbgym_cfg: DBGymConfig):
def benchmark_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("benchmark")
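Most hunks in this PR add a return annotation such as `-> None` to functions that previously had none. If `scripts/mypy.ini` enables a strict-style option such as `disallow_untyped_defs` (an assumption; the file's contents are not shown here), mypy reports every function that lacks one. A minimal sketch of a runtime audit for the same property follows; `is_fully_annotated`, `group_before`, and `group_after` are hypothetical names used only for illustration.

```python
import inspect
from typing import Any, Callable


def is_fully_annotated(fn: Callable[..., Any]) -> bool:
    """Return True if every parameter and the return value carry an annotation."""
    sig = inspect.signature(fn)
    params_ok = all(
        p.annotation is not inspect.Parameter.empty for p in sig.parameters.values()
    )
    return params_ok and sig.return_annotation is not inspect.Signature.empty


# Hypothetical stand-ins for the "before" and "after" versions in the diff.
def group_before(cfg: dict[str, str]):
    cfg["group"] = "benchmark"


def group_after(cfg: dict[str, str]) -> None:
    cfg["group"] = "benchmark"


print(is_fully_annotated(group_before))  # False: no return annotation
print(is_fully_annotated(group_after))   # True
```

This only inspects signatures; unlike mypy, it does not check that the function bodies agree with the annotations.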


Empty file added benchmark/tpch/__init__.py
Empty file.
20 changes: 10 additions & 10 deletions benchmark/tpch/cli.py
@@ -1,6 +1,5 @@
import logging
import os
import shutil
from pathlib import Path

import click

@@ -10,7 +9,6 @@
link_result,
workload_name_fn,
)
from util.pg import *
from util.shell import subprocess_run

benchmark_tpch_logger = logging.getLogger("benchmark/tpch")
@@ -19,7 +17,7 @@

@click.group(name="tpch")
@click.pass_obj
def tpch_group(dbgym_cfg: DBGymConfig):
def tpch_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("tpch")


@@ -28,7 +26,7 @@ def tpch_group(dbgym_cfg: DBGymConfig):
@click.pass_obj
# The reason generate data is separate from create dbdata is because generate-data is generic
# to all DBMSs while create dbdata is specific to a single DBMS.
def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float):
def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None:
_clone(dbgym_cfg)
_generate_data(dbgym_cfg, scale_factor)

@@ -59,7 +57,7 @@ def tpch_workload(
seed_end: int,
query_subset: str,
scale_factor: float,
):
) -> None:
assert (
seed_start <= seed_end
), f"seed_start ({seed_start}) must be <= seed_end ({seed_end})"
@@ -72,7 +70,7 @@ def _get_queries_dname(seed: int, scale_factor: float) -> str:
return f"queries_{seed}_sf{get_scale_factor_string(scale_factor)}"


def _clone(dbgym_cfg: DBGymConfig):
def _clone(dbgym_cfg: DBGymConfig) -> None:
expected_symlink_dpath = (
dbgym_cfg.cur_symlinks_build_path(mkdir=True) / "tpch-kit.link"
)
@@ -102,7 +100,7 @@ def _get_tpch_kit_dpath(dbgym_cfg: DBGymConfig) -> Path:

def _generate_queries(
dbgym_cfg: DBGymConfig, seed_start: int, seed_end: int, scale_factor: float
):
) -> None:
tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg)
data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
benchmark_tpch_logger.info(
@@ -132,7 +130,7 @@ def _generate_queries(
)


def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float):
def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None:
tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg)
data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
expected_tables_symlink_dpath = (
@@ -162,7 +160,7 @@ def _generate_workload(
seed_end: int,
query_subset: str,
scale_factor: float,
):
) -> None:
symlink_data_dpath = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset)
expected_workload_symlink_dpath = symlink_data_dpath / (workload_name + ".link")
@@ -177,6 +175,8 @@ def _generate_workload(
queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 0]
elif query_subset == "odd":
queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 1]
else:
assert False

with open(real_dpath / "order.txt", "w") as f:
for seed in range(seed_start, seed_end + 1):
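The hunk above closes the `query_subset` chain with `else: assert False`, so that `queries` is always bound before it is used. One caveat worth knowing: `assert` statements are removed when Python runs with `-O`, so the guard disappears in optimized mode. A minimal sketch of an alternative that raises explicitly is shown below; the function name `select_queries` and the `"all"` branch are assumptions, since the start of the chain is not visible in the diff.

```python
def select_queries(query_subset: str) -> list[str]:
    """Map a subset name to TPC-H query numbers (as strings)."""
    all_queries = range(1, 22 + 1)  # TPC-H defines 22 queries
    if query_subset == "all":  # assumed branch; not visible in the hunk
        queries = [f"{i}" for i in all_queries]
    elif query_subset == "even":
        queries = [f"{i}" for i in all_queries if i % 2 == 0]
    elif query_subset == "odd":
        queries = [f"{i}" for i in all_queries if i % 2 == 1]
    else:
        # Unlike `assert False`, this still fires under `python -O`.
        raise ValueError(f"unknown query_subset: {query_subset!r}")
    return queries


print(select_queries("even"))
```

Either form satisfies a type checker that `queries` is defined on every path; the explicit exception additionally gives callers a message naming the bad value.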
9 changes: 6 additions & 3 deletions benchmark/tpch/load_info.py
@@ -1,3 +1,6 @@
from pathlib import Path
from typing import Optional

from dbms.load_info_base_class import LoadInfoBaseClass
from misc.utils import DBGymConfig, get_scale_factor_string

@@ -55,11 +58,11 @@ def __init__(self, dbgym_cfg: DBGymConfig, scale_factor: float):
table_fpath = tables_dpath / f"{table}.tbl"
self._tables_and_fpaths.append((table, table_fpath))

def get_schema_fpath(self):
def get_schema_fpath(self) -> Path:
return self._schema_fpath

def get_tables_and_fpaths(self):
def get_tables_and_fpaths(self) -> list[tuple[str, Path]]:
return self._tables_and_fpaths

def get_constraints_fpath(self):
def get_constraints_fpath(self) -> Optional[Path]:
return self._constraints_fpath
Empty file added dbms/__init__.py
Empty file.
3 changes: 2 additions & 1 deletion dbms/cli.py
@@ -1,11 +1,12 @@
import click

from dbms.postgres.cli import postgres_group
from misc.utils import DBGymConfig


@click.group(name="dbms")
@click.pass_obj
def dbms_group(dbgym_cfg):
def dbms_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("dbms")


10 changes: 7 additions & 3 deletions dbms/load_info_base_class.py
@@ -1,16 +1,20 @@
from pathlib import Path
from typing import Optional


class LoadInfoBaseClass:
"""
A base class for providing info for DBMSs to load the data of a benchmark
When copying these functions to a specific benchmark's load_info.py file, don't
copy the comments or type annotations or else they might become out of sync.
"""

def get_schema_fpath(self) -> str:
def get_schema_fpath(self) -> Path:
raise NotImplemented

def get_tables_and_fpaths(self) -> list[(str, str)]:
def get_tables_and_fpaths(self) -> list[tuple[str, Path]]:
raise NotImplemented

# If the subclassing benchmark does not have constraints, you can return None here
def get_constraints_fpath(self) -> str | None:
def get_constraints_fpath(self) -> Optional[Path]:
raise NotImplemented
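A side note on the base class above: the methods use `raise NotImplemented`, which this PR leaves unchanged. `NotImplemented` is a constant intended for binary-operator methods, not an exception, so raising it produces a `TypeError` rather than the `NotImplementedError` a reader would probably expect. The sketch below demonstrates the difference; `BrokenBase`, `FixedBase`, and `raised_type` are hypothetical names, not part of the repository.

```python
from pathlib import Path


class BrokenBase:
    def get_schema_fpath(self) -> Path:
        # NotImplemented is not an exception class or instance.
        raise NotImplemented


class FixedBase:
    def get_schema_fpath(self) -> Path:
        # NotImplementedError is the exception meant for abstract methods.
        raise NotImplementedError


def raised_type(obj: object) -> type:
    """Call get_schema_fpath() and report which exception type came out."""
    try:
        obj.get_schema_fpath()  # type: ignore[attr-defined]
    except Exception as exc:
        return type(exc)
    return type(None)


print(raised_type(BrokenBase()).__name__)  # TypeError
print(raised_type(FixedBase()).__name__)   # NotImplementedError
```

A caller that writes `except NotImplementedError:` would therefore not catch the error raised by the current base class.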
Empty file added dbms/postgres/__init__.py
Empty file.
43 changes: 24 additions & 19 deletions dbms/postgres/cli.py
@@ -10,9 +10,10 @@
import shutil
import subprocess
from pathlib import Path
from typing import Optional

import click
from sqlalchemy import Connection
import sqlalchemy

from benchmark.tpch.load_info import TpchLoadInfo
from dbms.load_info_base_class import LoadInfoBaseClass
@@ -35,9 +36,9 @@
DEFAULT_POSTGRES_DBNAME,
DEFAULT_POSTGRES_PORT,
SHARED_PRELOAD_LIBRARIES,
conn_execute,
create_conn,
create_sqlalchemy_conn,
sql_file_execute,
sqlalchemy_conn_execute,
)
from util.shell import subprocess_run

@@ -47,7 +48,7 @@

@click.group(name="postgres")
@click.pass_obj
def postgres_group(dbgym_cfg: DBGymConfig):
def postgres_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("postgres")


@@ -61,7 +62,7 @@ def postgres_group(dbgym_cfg: DBGymConfig):
is_flag=True,
help="Include this flag to rebuild Postgres even if it already exists.",
)
def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool):
def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool) -> None:
_build_repo(dbgym_cfg, rebuild)


@@ -94,14 +95,14 @@ def postgres_dbdata(
dbgym_cfg: DBGymConfig,
benchmark_name: str,
scale_factor: float,
pgbin_path: Path,
pgbin_path: Optional[Path],
intended_dbdata_hardware: str,
dbdata_parent_dpath: Path,
):
dbdata_parent_dpath: Optional[Path],
) -> None:
# Set args to defaults programmatically (do this before doing anything else in the function)
if pgbin_path == None:
if pgbin_path is None:
pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path)
if dbdata_parent_dpath == None:
if dbdata_parent_dpath is None:
dbdata_parent_dpath = default_dbdata_parent_dpath(
dbgym_cfg.dbgym_workspace_path
)
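The hunk above changes the parameters to `Optional[Path]` and replaces `== None` with `is None`. The identity check is what PEP 8 recommends, because `==` dispatches to `__eq__`, which a class may override. A minimal sketch of the default-resolution pattern follows; `resolve_pgbin_path`, the body of `default_pgbin_path`, and `AlwaysEqual` are hypothetical and only illustrate the idea.

```python
from pathlib import Path
from typing import Optional


def default_pgbin_path(workspace: Path) -> Path:
    # Hypothetical default; the real helper's directory layout is not shown in the diff.
    return workspace / "pgbin"


def resolve_pgbin_path(workspace: Path, pgbin_path: Optional[Path]) -> Path:
    if pgbin_path is None:
        pgbin_path = default_pgbin_path(workspace)
    # After the `is None` branch, a type checker can treat pgbin_path as a plain Path.
    return pgbin_path


class AlwaysEqual:
    """Shows why `== None` is fragile: __eq__ can claim equality with anything."""

    def __eq__(self, other: object) -> bool:
        return True


obj = AlwaysEqual()
print(obj == None)  # True, although obj is a real object  # noqa: E711
print(obj is None)  # False

print(resolve_pgbin_path(Path("/ws"), None))
print(resolve_pgbin_path(Path("/ws"), Path("/custom/bin")))
```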
@@ -138,7 +139,7 @@ def _get_repo_symlink_path(dbgym_cfg: DBGymConfig) -> Path:
return dbgym_cfg.cur_symlinks_build_path("repo.link")


def _build_repo(dbgym_cfg: DBGymConfig, rebuild):
def _build_repo(dbgym_cfg: DBGymConfig, rebuild: bool) -> None:
expected_repo_symlink_dpath = _get_repo_symlink_path(dbgym_cfg)
if not rebuild and expected_repo_symlink_dpath.exists():
dbms_postgres_logger.info(
@@ -209,7 +210,7 @@ def _create_dbdata(
dbms_postgres_logger.info(f"Created dbdata in {dbdata_tgz_symlink_path}")


def _generic_dbdata_setup(dbgym_cfg: DBGymConfig):
def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None:
# get necessary vars
pgbin_real_dpath = _get_pgbin_symlink_path(dbgym_cfg).resolve()
assert pgbin_real_dpath.exists()
@@ -247,8 +248,8 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig):

def _load_benchmark_into_dbdata(
dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float
):
with create_conn(use_psycopg=False) as conn:
) -> None:
with create_sqlalchemy_conn() as conn:
if benchmark_name == "tpch":
load_info = TpchLoadInfo(dbgym_cfg, scale_factor)
else:
@@ -260,23 +261,27 @@ def _load_benchmark_into_dbdata(


def _load_into_dbdata(
dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass
):
dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, load_info: LoadInfoBaseClass
) -> None:
sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath())

# truncate all tables first before even loading a single one
for table, _ in load_info.get_tables_and_fpaths():
conn_execute(conn, f"TRUNCATE {table} CASCADE")
sqlalchemy_conn_execute(conn, f"TRUNCATE {table} CASCADE")
# then, load the tables
for table, table_fpath in load_info.get_tables_and_fpaths():
with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv:
with conn.connection.dbapi_connection.cursor() as cur:
assert conn.connection.dbapi_connection is not None
cur = conn.connection.dbapi_connection.cursor()
try:
with cur.copy(f"COPY {table} FROM STDIN CSV DELIMITER '|'") as copy:
while data := table_csv.read(8192):
copy.write(data)
finally:
cur.close()

constraints_fpath = load_info.get_constraints_fpath()
if constraints_fpath != None:
if constraints_fpath is not None:
sql_file_execute(dbgym_cfg, conn, constraints_fpath)
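In `_load_into_dbdata`, the `with ... cursor() as cur:` block becomes an explicit `try`/`finally` that calls `cur.close()`. The PR does not say why; one plausible reason (an assumption) is that the cursor type mypy sees for `dbapi_connection.cursor()` does not declare the context-manager protocol. The standard library's `contextlib.closing` expresses the same guarantee more compactly. The sketch below uses a hypothetical `FakeCursor` in place of a real database cursor so that it runs without a database.

```python
from contextlib import closing


class FakeCursor:
    """Stand-in for a DB-API cursor that has close() but no __enter__/__exit__."""

    def __init__(self) -> None:
        self.closed = False
        self.written: list[str] = []

    def write(self, data: str) -> None:
        if data == "bad":
            raise ValueError("simulated write failure")
        self.written.append(data)

    def close(self) -> None:
        self.closed = True


def load_rows(cur: FakeCursor, chunks: list[str]) -> None:
    # closing() calls cur.close() on exit, whether or not an exception was raised,
    # which matches the try/finally form in the diff.
    with closing(cur):
        for chunk in chunks:
            cur.write(chunk)


ok_cursor = FakeCursor()
load_rows(ok_cursor, ["a|b", "c|d"])
print(ok_cursor.closed, ok_cursor.written)

failing_cursor = FakeCursor()
try:
    load_rows(failing_cursor, ["a|b", "bad"])
except ValueError:
    pass
print(failing_cursor.closed)
```

Whether `closing` passes the project's mypy configuration with the real SQLAlchemy and psycopg types has not been verified here.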


18 changes: 12 additions & 6 deletions dependencies/requirements.txt
@@ -1,6 +1,7 @@
absl-py==2.1.0
aiosignal==1.3.1
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
black==24.2.0
cachetools==5.3.2
@@ -25,7 +26,7 @@ google-auth-oauthlib==1.0.0
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.60.0
gymnasium==0.28.1
gymnasium==0.29.1
h5py==3.10.0
hyperopt==0.2.7
idna==3.6
@@ -44,6 +45,7 @@ MarkupSafe==2.1.4
ml-dtypes==0.2.0
mpmath==1.3.0
msgpack==1.0.7
mypy==1.11.2
mypy-extensions==1.0.0
networkx==3.2.1
numpy==1.26.3
@@ -56,7 +58,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==8.5.0.96
nvidia-cudnn-cu12==8.9.2.26
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu11==10.2.10.91
@@ -66,14 +68,15 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu11==11.7.4.91
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu11==2.14.3
nvidia-nccl-cu12==2.19.3
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu11==11.7.91
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.2
pandas==2.2.0
pandas-stubs==2.2.2.240807
pathspec==0.12.1
pglast==6.2
platformdirs==4.2.0
@@ -92,6 +95,7 @@ pytz==2023.4
PyYAML==6.0.1
ray==2.9.3
record-keeper==0.9.32
redis==5.0.3
referencing==0.33.0
requests==2.31.0
requests-oauthlib==1.3.1
@@ -112,14 +116,16 @@ tensorflow-io-gcs-filesystem==0.36.0
termcolor==2.4.0
threadpoolctl==3.2.0
tomli==2.0.1
torch==2.0.0
torch==2.4.0
tqdm==4.66.1
triton==2.0.0
triton==3.0.0
types-python-dateutil==2.9.0.20240821
types-pytz==2024.1.0.20240417
types-PyYAML==6.0.12.20240808
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.2.0
virtualenv==20.25.0
Werkzeug==3.0.1
wrapt==1.14.1
zipp==3.17.0
redis==5.0.3