Skip to content

Commit

Permalink
Static type checking (#39)
Browse files Browse the repository at this point in the history
**Summary**: added `mypy --strict` to CI and fixed all type errors (~900
total) found by it. This also found and fixed some _logical_ errors in
the code.

**Demo**:
![Screenshot 2024-09-03 at 13 53
11](https://github.com/user-attachments/assets/d68c2828-644f-4fa3-a7d6-3e683da94ff3)
[Passing
CI](https://github.com/cmu-db/dbgym/actions/runs/10687793830/job/29626172277)

**Details**:
* Readability is much better now that everything is typed.
* `mypy` found that we were using a feature from torch 2.4.0 so I
upgraded torch from 2.0.0 -> 2.4.0.
* Some type errors were because of dead code. These were removed.
* Some type errors showed places where we were using the wrong type.
* Many asserts were added for `Optional[...]` types.
* Fixed a previously confusing situation around mixing
`psycopg.Connection` and `sqlalchemy.Connection` in
`pg.py:create_conn()`.
  • Loading branch information
wangpatrick57 authored Sep 5, 2024
1 parent ef24dc1 commit ac849f8
Show file tree
Hide file tree
Showing 83 changed files with 1,372 additions and 1,144 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ jobs:
run: |
./scripts/check_format.sh
- name: Static type checking
run: |
mypy --config-file scripts/mypy.ini .
- name: Run unit tests
run: |
. "$HOME/.cargo/env"
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__pycache__/
.mypy_cache/
.conda/
.idea/
test_clean_scratchspace/
Expand Down
Empty file added benchmark/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@click.group(name="benchmark")
@click.pass_obj
def benchmark_group(dbgym_cfg: DBGymConfig):
def benchmark_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("benchmark")


Expand Down
Empty file added benchmark/tpch/__init__.py
Empty file.
20 changes: 10 additions & 10 deletions benchmark/tpch/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import shutil
from pathlib import Path

import click

Expand All @@ -10,7 +9,6 @@
link_result,
workload_name_fn,
)
from util.pg import *
from util.shell import subprocess_run

benchmark_tpch_logger = logging.getLogger("benchmark/tpch")
Expand All @@ -19,7 +17,7 @@

@click.group(name="tpch")
@click.pass_obj
def tpch_group(dbgym_cfg: DBGymConfig):
def tpch_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("tpch")


Expand All @@ -28,7 +26,7 @@ def tpch_group(dbgym_cfg: DBGymConfig):
@click.pass_obj
# The reason generate data is separate from create dbdata is because generate-data is generic
# to all DBMSs while create dbdata is specific to a single DBMS.
def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float):
def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None:
_clone(dbgym_cfg)
_generate_data(dbgym_cfg, scale_factor)

Expand Down Expand Up @@ -59,7 +57,7 @@ def tpch_workload(
seed_end: int,
query_subset: str,
scale_factor: float,
):
) -> None:
assert (
seed_start <= seed_end
), f"seed_start ({seed_start}) must be <= seed_end ({seed_end})"
Expand All @@ -72,7 +70,7 @@ def _get_queries_dname(seed: int, scale_factor: float) -> str:
return f"queries_{seed}_sf{get_scale_factor_string(scale_factor)}"


def _clone(dbgym_cfg: DBGymConfig):
def _clone(dbgym_cfg: DBGymConfig) -> None:
expected_symlink_dpath = (
dbgym_cfg.cur_symlinks_build_path(mkdir=True) / "tpch-kit.link"
)
Expand Down Expand Up @@ -102,7 +100,7 @@ def _get_tpch_kit_dpath(dbgym_cfg: DBGymConfig) -> Path:

def _generate_queries(
dbgym_cfg: DBGymConfig, seed_start: int, seed_end: int, scale_factor: float
):
) -> None:
tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg)
data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
benchmark_tpch_logger.info(
Expand Down Expand Up @@ -132,7 +130,7 @@ def _generate_queries(
)


def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float):
def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None:
tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg)
data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
expected_tables_symlink_dpath = (
Expand Down Expand Up @@ -162,7 +160,7 @@ def _generate_workload(
seed_end: int,
query_subset: str,
scale_factor: float,
):
) -> None:
symlink_data_dpath = dbgym_cfg.cur_symlinks_data_path(mkdir=True)
workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset)
expected_workload_symlink_dpath = symlink_data_dpath / (workload_name + ".link")
Expand All @@ -177,6 +175,8 @@ def _generate_workload(
queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 0]
elif query_subset == "odd":
queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 1]
else:
assert False

with open(real_dpath / "order.txt", "w") as f:
for seed in range(seed_start, seed_end + 1):
Expand Down
9 changes: 6 additions & 3 deletions benchmark/tpch/load_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from pathlib import Path
from typing import Optional

from dbms.load_info_base_class import LoadInfoBaseClass
from misc.utils import DBGymConfig, get_scale_factor_string

Expand Down Expand Up @@ -55,11 +58,11 @@ def __init__(self, dbgym_cfg: DBGymConfig, scale_factor: float):
table_fpath = tables_dpath / f"{table}.tbl"
self._tables_and_fpaths.append((table, table_fpath))

def get_schema_fpath(self):
def get_schema_fpath(self) -> Path:
return self._schema_fpath

def get_tables_and_fpaths(self):
def get_tables_and_fpaths(self) -> list[tuple[str, Path]]:
return self._tables_and_fpaths

def get_constraints_fpath(self):
def get_constraints_fpath(self) -> Optional[Path]:
return self._constraints_fpath
Empty file added dbms/__init__.py
Empty file.
3 changes: 2 additions & 1 deletion dbms/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import click

from dbms.postgres.cli import postgres_group
from misc.utils import DBGymConfig


@click.group(name="dbms")
@click.pass_obj
def dbms_group(dbgym_cfg):
def dbms_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("dbms")


Expand Down
10 changes: 7 additions & 3 deletions dbms/load_info_base_class.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from pathlib import Path
from typing import Optional


class LoadInfoBaseClass:
"""
A base class for providing info for DBMSs to load the data of a benchmark
When copying these functions to a specific benchmark's load_info.py file, don't
copy the comments or type annotations or else they might become out of sync.
"""

def get_schema_fpath(self) -> str:
def get_schema_fpath(self) -> Path:
raise NotImplemented

def get_tables_and_fpaths(self) -> list[(str, str)]:
def get_tables_and_fpaths(self) -> list[tuple[str, Path]]:
raise NotImplemented

# If the subclassing benchmark does not have constraints, you can return None here
def get_constraints_fpath(self) -> str | None:
def get_constraints_fpath(self) -> Optional[Path]:
raise NotImplemented
Empty file added dbms/postgres/__init__.py
Empty file.
43 changes: 24 additions & 19 deletions dbms/postgres/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import shutil
import subprocess
from pathlib import Path
from typing import Optional

import click
from sqlalchemy import Connection
import sqlalchemy

from benchmark.tpch.load_info import TpchLoadInfo
from dbms.load_info_base_class import LoadInfoBaseClass
Expand All @@ -35,9 +36,9 @@
DEFAULT_POSTGRES_DBNAME,
DEFAULT_POSTGRES_PORT,
SHARED_PRELOAD_LIBRARIES,
conn_execute,
create_conn,
create_sqlalchemy_conn,
sql_file_execute,
sqlalchemy_conn_execute,
)
from util.shell import subprocess_run

Expand All @@ -47,7 +48,7 @@

@click.group(name="postgres")
@click.pass_obj
def postgres_group(dbgym_cfg: DBGymConfig):
def postgres_group(dbgym_cfg: DBGymConfig) -> None:
dbgym_cfg.append_group("postgres")


Expand All @@ -61,7 +62,7 @@ def postgres_group(dbgym_cfg: DBGymConfig):
is_flag=True,
help="Include this flag to rebuild Postgres even if it already exists.",
)
def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool):
def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool) -> None:
_build_repo(dbgym_cfg, rebuild)


Expand Down Expand Up @@ -94,14 +95,14 @@ def postgres_dbdata(
dbgym_cfg: DBGymConfig,
benchmark_name: str,
scale_factor: float,
pgbin_path: Path,
pgbin_path: Optional[Path],
intended_dbdata_hardware: str,
dbdata_parent_dpath: Path,
):
dbdata_parent_dpath: Optional[Path],
) -> None:
# Set args to defaults programmatically (do this before doing anything else in the function)
if pgbin_path == None:
if pgbin_path is None:
pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path)
if dbdata_parent_dpath == None:
if dbdata_parent_dpath is None:
dbdata_parent_dpath = default_dbdata_parent_dpath(
dbgym_cfg.dbgym_workspace_path
)
Expand Down Expand Up @@ -138,7 +139,7 @@ def _get_repo_symlink_path(dbgym_cfg: DBGymConfig) -> Path:
return dbgym_cfg.cur_symlinks_build_path("repo.link")


def _build_repo(dbgym_cfg: DBGymConfig, rebuild):
def _build_repo(dbgym_cfg: DBGymConfig, rebuild: bool) -> None:
expected_repo_symlink_dpath = _get_repo_symlink_path(dbgym_cfg)
if not rebuild and expected_repo_symlink_dpath.exists():
dbms_postgres_logger.info(
Expand Down Expand Up @@ -209,7 +210,7 @@ def _create_dbdata(
dbms_postgres_logger.info(f"Created dbdata in {dbdata_tgz_symlink_path}")


def _generic_dbdata_setup(dbgym_cfg: DBGymConfig):
def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None:
# get necessary vars
pgbin_real_dpath = _get_pgbin_symlink_path(dbgym_cfg).resolve()
assert pgbin_real_dpath.exists()
Expand Down Expand Up @@ -247,8 +248,8 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig):

def _load_benchmark_into_dbdata(
dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float
):
with create_conn(use_psycopg=False) as conn:
) -> None:
with create_sqlalchemy_conn() as conn:
if benchmark_name == "tpch":
load_info = TpchLoadInfo(dbgym_cfg, scale_factor)
else:
Expand All @@ -260,23 +261,27 @@ def _load_benchmark_into_dbdata(


def _load_into_dbdata(
dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass
):
dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, load_info: LoadInfoBaseClass
) -> None:
sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath())

# truncate all tables first before even loading a single one
for table, _ in load_info.get_tables_and_fpaths():
conn_execute(conn, f"TRUNCATE {table} CASCADE")
sqlalchemy_conn_execute(conn, f"TRUNCATE {table} CASCADE")
# then, load the tables
for table, table_fpath in load_info.get_tables_and_fpaths():
with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv:
with conn.connection.dbapi_connection.cursor() as cur:
assert conn.connection.dbapi_connection is not None
cur = conn.connection.dbapi_connection.cursor()
try:
with cur.copy(f"COPY {table} FROM STDIN CSV DELIMITER '|'") as copy:
while data := table_csv.read(8192):
copy.write(data)
finally:
cur.close()

constraints_fpath = load_info.get_constraints_fpath()
if constraints_fpath != None:
if constraints_fpath is not None:
sql_file_execute(dbgym_cfg, conn, constraints_fpath)


Expand Down
18 changes: 12 additions & 6 deletions dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
absl-py==2.1.0
aiosignal==1.3.1
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
black==24.2.0
cachetools==5.3.2
Expand All @@ -25,7 +26,7 @@ google-auth-oauthlib==1.0.0
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.60.0
gymnasium==0.28.1
gymnasium==0.29.1
h5py==3.10.0
hyperopt==0.2.7
idna==3.6
Expand All @@ -44,6 +45,7 @@ MarkupSafe==2.1.4
ml-dtypes==0.2.0
mpmath==1.3.0
msgpack==1.0.7
mypy==1.11.2
mypy-extensions==1.0.0
networkx==3.2.1
numpy==1.26.3
Expand All @@ -56,7 +58,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==8.5.0.96
nvidia-cudnn-cu12==8.9.2.26
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu11==10.2.10.91
Expand All @@ -66,14 +68,15 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu11==11.7.4.91
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu11==2.14.3
nvidia-nccl-cu12==2.19.3
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu11==11.7.91
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.2
pandas==2.2.0
pandas-stubs==2.2.2.240807
pathspec==0.12.1
pglast==6.2
platformdirs==4.2.0
Expand All @@ -92,6 +95,7 @@ pytz==2023.4
PyYAML==6.0.1
ray==2.9.3
record-keeper==0.9.32
redis==5.0.3
referencing==0.33.0
requests==2.31.0
requests-oauthlib==1.3.1
Expand All @@ -112,14 +116,16 @@ tensorflow-io-gcs-filesystem==0.36.0
termcolor==2.4.0
threadpoolctl==3.2.0
tomli==2.0.1
torch==2.0.0
torch==2.4.0
tqdm==4.66.1
triton==2.0.0
triton==3.0.0
types-python-dateutil==2.9.0.20240821
types-pytz==2024.1.0.20240417
types-PyYAML==6.0.12.20240808
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.2.0
virtualenv==20.25.0
Werkzeug==3.0.1
wrapt==1.14.1
zipp==3.17.0
redis==5.0.3
Loading

0 comments on commit ac849f8

Please sign in to comment.