Skip to content

Commit

Permalink
add new changes
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli committed Dec 27, 2024
1 parent 73af7c9 commit 5d0cf9b
Show file tree
Hide file tree
Showing 50 changed files with 2,569 additions and 1,048 deletions.
54 changes: 32 additions & 22 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ CASUAL_CONV_VERSION = 1.4.0
.DEFAULT_GOAL := all

.PHONY: all
all: format lint check_types test
all: format lint check_types test

.PHONY: .pre-commit
.pre-commit:
Expand All @@ -17,9 +17,9 @@ all: format lint check_types test
.PHONY: .install_common
.install_common:
@echo "Installing common Python dependencies"
uv sync --inexact --frozen
uv sync --all-extras --frozen

.PHONY: install_common_ci
.PHONY: install_common_ci
install_common_ci:
@echo "[CI] Installing common Python dependencies"
uv sync --inexact --frozen --no-cache --quiet
Expand All @@ -28,37 +28,37 @@ install_common_ci:
install_cpu: .install_common .pre-commit
@echo "Overriding/installing PyTorch (CPU)"
uv pip install --reinstall \
torch==${TORCH_VERSION}
"torch>=${TORCH_VERSION}"

.PHONY: install_cuda
install_cuda: .install_common .pre-commit
@echo "Overriding/installing PyTorch (GPU)"
uv pip install --reinstall \
torch==${TORCH_VERSION} --index-url ${CUDA_INDEX_URL}
"torch>=${TORCH_VERSION}" --index-url ${CUDA_INDEX_URL}

.PHONY: install_ci
install_ci: install_common_ci
@echo "[CI] Overriding/installing PyTorch (CPU)"
uv pip install --reinstall --no-cache --quiet \
torch==${TORCH_VERSION}
"torch>=${TORCH_VERSION}"

.PHONY: install_mamba
install_mamba:
uv pip install --no-build-isolation \
mamba-ssm==${MAMBA_VERSION} \
causal-conv1d==${CASUAL_CONV_VERSION}
"mamba-ssm>=${MAMBA_VERSION}" \
"causal-conv1d>=${CASUAL_CONV_VERSION}"

.PHONY: check_types
check_types:
uv run mypy src tests
uv run mypy src tests scripts

.PHONY: lint
lint:
uv run ruff check src tests
uv run ruff check src tests scripts

.PHONY: format
format:
uv run ruff format src tests
uv run ruff format src tests scripts

.PHONY: test
test: test_unit test_integration test_e2e
Expand All @@ -69,7 +69,17 @@ test_unit:

.PHONY: test_integration
test_integration:
uv run pytest tests/integration
$(MAKE) test_integration_install
$(MAKE) test_integration_config

.PHONY: test_integration_install
test_integration_install:
uv run --project tests/integration/install --reinstall-package mblm --quiet \
pytest tests/integration/install

.PHONY: test_integration_config
test_integration_config:
uv run pytest tests/integration/config

E2E_RUN_TORCH = OMP_NUM_THREADS=1 \
uv run torchrun --nproc_per_node=2 \
Expand All @@ -86,8 +96,8 @@ test_e2e:

.PHONY: test_e2e_grad_acc
test_e2e_grad_acc:
TEST_ID=grad_acc_1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-1.yml
TEST_ID=grad_acc_2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-2.yml
TEST_ID=grad_acc_1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-1.yaml
TEST_ID=grad_acc_2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-2.yaml
$(E2E_RUN_VALIDATE) \
--check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_2/loss.csv \
--check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_1/loss.csv
Expand All @@ -96,31 +106,31 @@ test_e2e_grad_acc:
test_e2e_trainer:
# ---------------- Chain 3 runs of 1 epoch each ----------------
@echo "Running training for a single epoch"
TEST_ID=1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1-epoch.yml
TEST_ID=1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1-epoch.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_1

@echo "Chaining run 2 to run 1"
TEST_ID=2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_1/config.yml
TEST_ID=2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_1/config.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_2

@echo "Chaining run 3 to run 2"
TEST_ID=3 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_2/config.yml
TEST_ID=3 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_2/config.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_3

@echo "Asserting on chained training runs 1, 2, 3"
$(E2E_RUN_VALIDATE) \
--assert-equal-epochs \
--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_1/loss.csv \
--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_2/loss.csv \
--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_3/loss.csv
--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_3/loss.csv

# ------------ Chain 2 runs of more than 1 epoch each ------------
@echo "Running training for more than 1 epoch"
TEST_ID=4 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1.5-epoch.yml
TEST_ID=4 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1.5-epoch.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_4

@echo "Chaining run 5 to run 4"
TEST_ID=5 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_4/config.yml
TEST_ID=5 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_4/config.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_5

@echo "Asserting on chained training runs 4, 5"
Expand All @@ -130,11 +140,11 @@ test_e2e_trainer:

# ------------ Chain 2 runs of less than 1 epoch each ------------
@echo "Running training for less than 1 epoch"
TEST_ID=6 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-0.5-epoch.yml
TEST_ID=6 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-0.5-epoch.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_6

@echo "Chaining run 7 to run 6"
TEST_ID=7 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_6/config.yml
TEST_ID=7 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_6/config.yaml
$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_7

@echo "Asserting on chained training runs 6, 7"
Expand Down
Loading

0 comments on commit 5d0cf9b

Please sign in to comment.