Merge pull request #11 from DubiousCactus/builder
Builder
DubiousCactus authored Oct 23, 2024
2 parents 3ff1d2a + 8da91bf commit 88b6b85
Showing 21 changed files with 973 additions and 301 deletions.
31 changes: 31 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,31 @@
{
  "configurations": [
    {
      "name": "Launch Train [exp_a]",
      "type": "python",
      "request": "launch",
      "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python",
      "autoReload": { "enable": true },
      "program": "${workspaceFolder}/train.py",
      "args": ["+experiment=exp_a", "dataset.tiny=1"]
    },
    {
      "name": "Launch Build [exp_a]",
      "type": "python",
      "request": "launch",
      "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python",
      "autoReload": { "enable": true },
      "program": "${workspaceFolder}/build.py",
      "args": ["+experiment=exp_a", "dataset.tiny=1"]
    },
    {
      "name": "Attach Build [exp_a]",
      "type": "python",
      "request": "attach",
      "connect": {
        "host": "localhost",
        "port": 5555
      }
    }
  ]
}
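The "Attach Build [exp_a]" configuration assumes a debug server is already listening on localhost:5555; this diff doesn't show how build.py serves that port. A minimal sketch of one way to do it, assuming the debugpy package:

import debugpy

# Listen on the host/port targeted by the "Attach Build [exp_a]" configuration.
debugpy.listen(("localhost", 5555))
print("Waiting for the debugger to attach on port 5555...")
debugpy.wait_for_client()  # Block until the IDE attaches.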
38 changes: 38 additions & 0 deletions bootstrap/__init__.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable

from hydra_zen.typing import Partial


class MatchboxModule:
    PREV = "MatchboxModule.PREV"  # TODO: This sentinel is used like an enum value; revisit.

    def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs):
        # TODO: Flesh out this class. It's a hack; I'm still figuring things out
        # as I go.
        self._str_rep = name
        self.underlying_fn = fn.func if isinstance(fn, partial) else fn
        self.partial = partial(fn, *args, **kwargs)

    def __call__(self, prev_result: Any) -> Any:
        # Replace the .PREV placeholder in the function's args/kwargs with
        # prev_result. NOTE: partial.args is an immutable tuple, so we build new
        # arguments instead of assigning into it (which would raise a TypeError).
        args = list(self.partial.args)
        for i, arg in enumerate(args):
            if arg == self.PREV:
                assert prev_result is not None
                args[i] = prev_result
        kwargs = dict(self.partial.keywords)
        for key, value in kwargs.items():
            if value == self.PREV:
                assert prev_result is not None
                kwargs[key] = prev_result
        return self.partial.func(*args, **kwargs)

    def __str__(self) -> str:
        return self._str_rep


@dataclass
class MatchboxModuleState:
    first_run: bool
    result: Any
    is_frozen: bool
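To illustrate the intended use of MatchboxModule (a sketch; load_dataset and make_model below are hypothetical stand-ins for the real factories): modules are chained so that a MatchboxModule.PREV placeholder is substituted with the previous module's result.

from bootstrap import MatchboxModule

def load_dataset(split: str):  # Hypothetical factory.
    return [1, 2, 3] if split == "train" else []

def make_model(dataset):  # Hypothetical factory.
    return {"n_samples": len(dataset)}

dataset_mod = MatchboxModule("Dataset", load_dataset, split="train")
model_mod = MatchboxModule("Model", make_model, dataset=MatchboxModule.PREV)

out = dataset_mod(prev_result=None)  # Calls load_dataset(split="train").
out = model_mod(prev_result=out)     # PREV is replaced by the dataset.
print(out)                           # -> {'n_samples': 3}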
141 changes: 141 additions & 0 deletions bootstrap/launch_experiment.py
@@ -25,6 +25,7 @@
from rich.syntax import Syntax
from torch.utils.data import DataLoader, Dataset

from bootstrap import MatchboxModule
from bootstrap.factories import (
    make_dataloaders,
    make_datasets,
@@ -34,6 +35,7 @@
    make_training_loss,
    parallelize_model,
)
from bootstrap.tui.builder_ui import BuilderUI
from bootstrap.tui.training_ui import TrainingUI
from conf import project as project_conf
from src.base_tester import BaseTester
@@ -105,6 +107,145 @@ def init_wandb(
    wandb.watch(model, log=log, log_graph=log_graph)  # type: ignore


def launch_builder(
    run,  # type: ignore
    data_loader: Partial[DataLoader[Any]],
    optimizer: Partial[torch.optim.Optimizer],  # pyright: ignore
    scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
    trainer: Partial[BaseTrainer],
    tester: Partial[BaseTester],
    dataset: Partial[Dataset[Any]],
    model: Partial[torch.nn.Module],
    training_loss: Partial[torch.nn.Module],
):
    exp_conf = hydra_zen.to_yaml(
        dict(
            run_conf=run,
            dataset=dataset,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            training_loss=training_loss,
        )
    )
    # TODO: Overwrite data_loader.num_workers=0
    # data_loader.num_workers = 0

    async def launch_with_async_gui():
        tui = BuilderUI()
        task = asyncio.create_task(tui.run_async())
        await asyncio.sleep(0.5)  # Give the app time to start up
        while not tui.is_running:
            await asyncio.sleep(0.01)  # Poll until the app reports it is running
        # trace_catcher = TraceCatcher(tui)

        # ============ Partials instantiation ============
        # NOTE: This needs more thought. We need a mechanism for conditional hot
        # code reloading at the places below. We'll never re-run the entire
        # program while in the builder; we'll just reload pieces of code and
        # restart execution at specific places.

        # train_dataset = await trace_catcher.catch_and_hang(
        #     dataset, split="train", seed=run.seed, progress=None, job_id=None
        # )
        # model_inst = await trace_catcher.catch_and_hang(
        #     make_model, model, train_dataset
        # )
        # opt_inst = await trace_catcher.catch_and_hang(
        #     make_optimizer, optimizer, model_inst
        # )
        # scheduler_inst = await trace_catcher.catch_and_hang(
        #     make_scheduler, scheduler, opt_inst, run.epochs
        # )
        # training_loss_inst = await trace_catcher.catch_and_hang(
        #     make_training_loss, run.training_mode, training_loss
        # )
        # if model_inst is not None:
        #     model_inst = to_cuda_(parallelize_model(model_inst))
        # if training_loss_inst is not None:
        #     training_loss_inst = to_cuda_(training_loss_inst)
        tui.chain_up(
            [
                MatchboxModule(
                    "Dataset",
                    dataset,  # TODO: Fix the code reloading, then revert to using the dataset factory
                    split="train",
                    seed=run.seed,
                    progress=None,
                    job_id=None,
                ),
                MatchboxModule(
                    "Model",
                    make_model,
                    model,
                    dataset=dataset,
                ),
                MatchboxModule(
                    "Optimizer", make_optimizer, optimizer, model=MatchboxModule.PREV
                ),
                MatchboxModule(
                    "Scheduler",
                    make_scheduler,
                    scheduler,
                    optimizer=MatchboxModule.PREV,
                    epochs=run.epochs,
                ),
                MatchboxModule(
                    "Loss", make_training_loss, run.training_mode, training_loss
                ),
            ]
        )
        tui.run_chain()
        # all_success = False  # TODO:
        # if all_success:
        #     # TODO: idk how to handle this YET
        #     # Somehow, the dataloader will crash if it's not forked when using
        #     # multiprocessing along with Textual.
        #     mp.set_start_method("fork")
        #     train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders(
        #         data_loader,
        #         train_dataset,
        #         val_dataset,
        #         test_dataset,
        #         run.training_mode,
        #         run.seed,
        #     )
        #     init_wandb("test-run", model_inst, exp_conf)
        #
        #     model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode)
        #     common_args = dict(
        #         run_name="build-run",
        #         model=model_inst,
        #         model_ckpt_path=model_ckpt_path,
        #         training_loss=training_loss_inst,
        #         tui=tui,
        #     )
        #     if training_loss_inst is None:
        #         raise ValueError("training_loss must be defined in training mode!")
        #     if val_loader_inst is None or train_loader_inst is None:
        #         raise ValueError(
        #             "val_loader and train_loader must be defined in training mode!"
        #         )
        #     await trainer(
        #         train_loader=train_loader_inst,
        #         val_loader=val_loader_inst,
        #         opt=opt_inst,
        #         scheduler=scheduler_inst,
        #         **common_args,
        #         **asdict(run),
        #     ).train(
        #         epochs=run.epochs,
        #         val_every=run.val_every,
        #         visualize_every=run.viz_every,
        #         visualize_train_every=run.viz_train_every,
        #         visualize_n_samples=run.viz_num_samples,
        #     )
        _ = await task

    asyncio.run(launch_with_async_gui())


def launch_experiment(
    run,  # type: ignore
    data_loader: Partial[DataLoader[Any]],
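An aside on the startup pattern in launch_with_async_gui above: a Textual app can be started in the background with run_async() and driven from outside once its is_running flag turns true. A minimal, self-contained sketch of that handshake, assuming the textual package (MinimalUI is a placeholder for BuilderUI):

import asyncio

from textual.app import App


class MinimalUI(App):
    """Placeholder app standing in for BuilderUI."""


async def main():
    tui = MinimalUI()
    task = asyncio.create_task(tui.run_async())
    while not tui.is_running:  # Same handshake as in launch_with_async_gui.
        await asyncio.sleep(0.01)
    tui.exit()  # A real builder would run its module chain here instead.
    await task


asyncio.run(main())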
