Skip to content

Commit

Permalink
chore: add tests and trainer (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli authored Dec 26, 2024
1 parent 191dc90 commit 73af7c9
Show file tree
Hide file tree
Showing 54 changed files with 5,165 additions and 99 deletions.
53 changes: 53 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: CI

on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  check:
    name: Lint and check types
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v4
      - name: 'Set up Python'
        uses: actions/setup-python@v5
        with:
          python-version-file: '.python-version'
      - name: Lint
        run: make lint
      - name: Check formatting
        run: uv run ruff format src tests --check
      - name: Check types
        run: make check_types

  test:
    name: Test Python ${{ matrix.python-version }}
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - '3.10'
          - '3.11'
          - '3.12'
    steps:
      - uses: actions/checkout@v4
      - name: Install uv and set the python version
        uses: astral-sh/setup-uv@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set up Python
        run: uv python install
      - name: Install the project
        # Bug fix: this step previously repeated `uv python install`,
        # so the project's dependencies were never installed before the
        # test targets ran. `uv sync` installs the project + dev deps.
        run: uv sync --all-extras --dev
      - name: Run unit tests
        run: make test_unit
      - name: Run integration tests
        run: make test_integration
      - name: Run e2e tests
        run: make test_e2e
90 changes: 86 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,100 @@ install_mamba:
uv pip install --no-build-isolation \
mamba-ssm==${MAMBA_VERSION} \
causal-conv1d==${CASUAL_CONV_VERSION}

# Static type-check the sources and the test suite with mypy.
# (The old `uv run mypy src` invocation is removed: it is fully subsumed
# by checking `src tests` and only ran mypy twice over the same files.)
.PHONY: check_types
check_types:
	uv run mypy src tests

# Lint sources and tests with ruff.
# (The old `uv run ruff check src` invocation is removed: it is fully
# subsumed by checking `src tests` and only duplicated work.)
.PHONY: lint
lint:
	uv run ruff check src tests

# Auto-format sources and tests with ruff.
# (The old `uv run ruff format src` invocation is removed: it is fully
# subsumed by formatting `src tests` and only duplicated work.)
.PHONY: format
format:
	uv run ruff format src tests

# Run the full test suite: unit, integration and end-to-end, in that order.
.PHONY: test
test: test_unit test_integration test_e2e

# Run only the unit tests (tests/unit).
.PHONY: test_unit
test_unit:
	uv run pytest tests/unit

# Run only the integration tests (tests/integration).
.PHONY: test_integration
test_integration:
	uv run pytest tests/integration

# Entry points for the distributed e2e trainer tests.
# `:=` (simple expansion) instead of `=`: the values are static, so expand
# them once at parse time rather than on every reference.
# OMP_NUM_THREADS=1 pins each torchrun worker to a single OpenMP thread.
E2E_RUN_TORCH := OMP_NUM_THREADS=1 \
	uv run torchrun --nproc_per_node=2 \
	tests/e2e/trainer/run_trainer.py
E2E_RUN_VALIDATE := uv run tests/e2e/trainer/validate.py
E2E_TEST_ROOT := tests/e2e/trainer

# Run all end-to-end tests. Wipes previous e2e outputs first, then drives
# the two sub-targets via $(MAKE) so -n/-j flags propagate.
.PHONY: test_e2e
test_e2e:
	@echo "Clearing test output dir $(E2E_TEST_ROOT)/outputs"
	rm -rf $(E2E_TEST_ROOT)/outputs/*
	$(MAKE) test_e2e_grad_acc
	$(MAKE) test_e2e_trainer

# Gradient-accumulation e2e test: run the trainer twice with the two
# grad-acc sample configs, then have validate.py compare the resulting
# loss CSVs.
# NOTE(review): the run-2 CSV is passed before the run-1 CSV — confirm
# that validate.py's --check-grad-acc-csv arguments are order-insensitive
# (or that this order is intentional).
.PHONY: test_e2e_grad_acc
test_e2e_grad_acc:
	TEST_ID=grad_acc_1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-1.yml
	TEST_ID=grad_acc_2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-2.yml
	$(E2E_RUN_VALIDATE) \
		--check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_2/loss.csv \
		--check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_1/loss.csv

# Chained-resume e2e tests. Each run writes an output dir containing a
# config.yml; the next run resumes by consuming that config.yml, and
# validate.py asserts on each output dir and on the combined loss CSVs.
# Three scenarios are covered: runs of exactly 1 epoch, >1 epoch (1.5),
# and <1 epoch (0.5). Statement order matters — runs must complete
# before their successors start.
.PHONY: test_e2e_trainer
test_e2e_trainer:
	# ---------------- Chain 3 runs of 1 epoch each ----------------
	@echo "Running training for a single epoch"
	TEST_ID=1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1-epoch.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_1

	@echo "Chaining run 2 to run 1"
	TEST_ID=2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_1/config.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_2

	@echo "Chaining run 3 to run 2"
	TEST_ID=3 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_2/config.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_3

	@echo "Asserting on chained training runs 1, 2, 3"
	$(E2E_RUN_VALIDATE) \
		--assert-equal-epochs \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_1/loss.csv \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_2/loss.csv \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_3/loss.csv

	# ------------ Chain 2 runs of more than 1 epoch each ------------
	@echo "Running training for more than 1 epoch"
	TEST_ID=4 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1.5-epoch.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_4

	@echo "Chaining run 5 to run 4"
	TEST_ID=5 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_4/config.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_5

	@echo "Asserting on chained training runs 4, 5"
	$(E2E_RUN_VALIDATE) \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_4/loss.csv \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_5/loss.csv

	# ------------ Chain 2 runs of less than 1 epoch each ------------
	@echo "Running training for less than 1 epoch"
	TEST_ID=6 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-0.5-epoch.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_6

	@echo "Chaining run 7 to run 6"
	TEST_ID=7 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_6/config.yml
	$(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_7

	@echo "Asserting on chained training runs 6, 7"
	$(E2E_RUN_VALIDATE) \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_6/loss.csv \
		--check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_7/loss.csv

.PHONY: clear_cache
clear_cache:
Expand Down
54 changes: 54 additions & 0 deletions config/clevr_7mi_360m_1d_s_pt1_ft.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Experiment config: CLEVR fine-tune, ~7M target elements, 360M params, 1D model.
# NOTE(review): nesting reconstructed from context — indentation was lost in
# extraction; verify section boundaries against the trainer's config schema.
io:
  name_model: 8k_7mi_360m_1d_s_pt1_ft
  # NOTE(review): the "< set this!" marker below is part of the scalar value
  # and must be replaced with a real path before this config is usable.
  output_dir: /data/output < set this!
  dataset_dir: /data/datasets/clevr < set this!
  dataset_id: clevr
  dataset_args:
    target_mode: a
    qiqa_loss_mask: [0.0, 0.0, 0.0, 1]
    answer_categorical: true
    resize_to_w: 62
    resize_to_h: 41
    crop_h_perc: 0.1
    crop_w_perc: 0.1
    eom_token_id: 129
    som_text_token_id: 130
    som_image_token_id: 131
    downsample_channels: null
    shift_channels_start: null
  num_models_to_save: 5
  validate_amount: 100
  log_train_loss_amount: 1000
  # Placeholder description — fill in per experiment.
  description: >-
    Describe this!
params:
  num_tokens: 256
  pad_token_id: 128
  input_seq_len: 8192
  seq_lens: [8192]
  hidden_dims: [1024]
  num_layers: [54]
  train_checkpoint_chunks: null
  block:
    d_state: 128
    d_conv: 4
    expand: 2
    headdim: 64
    dropout: 0.1
    patch_pos_emb_type: null
train:
  target_elements: 7_000_000
  target_elements_strategy: batch
  batch_size: 6
  max_eval_steps: 1000
  shuffle_train: true
  learning_rate: 0.0001
  gradient_clipping: 0.5
  gradient_accumulate_every: 48
# Resume settings for continuing from a prior checkpoint.
resume:
  # NOTE(review): placeholder path — must point at an existing checkpoint.
  checkpoint_file: /path-to-checkpoint.pth
  next_batch_index: 0
  next_epoch_index: 0
  migrate_embeddings: false
  rename_modules: true
  resumed_from: null
31 changes: 31 additions & 0 deletions config/pg19_30bb_360m_1d_t.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Experiment config: PG-19, 30B target elements, 360M params, 1D transformer.
# NOTE(review): nesting reconstructed from context — indentation was lost in
# extraction; verify against the trainer's config schema.
io:
  name_model: 8k_30b_360m_1d_t
  # NOTE(review): the "< set this!" markers are part of the scalar values and
  # must be replaced with real paths before this config is usable.
  output_dir: /data/output < set this!
  dataset_dir: /data/datasets/pg19 < set this!
  dataset_id: pg19
  num_models_to_save: 5
  validate_amount: 100
  log_train_loss_amount: 1000
params:
  num_tokens: 256
  pad_token_id: 0
  input_seq_len: 8192
  seq_lens: [8192]
  hidden_dims: [1024]
  num_layers: [42]
  train_checkpoint_chunks: null
  # Attention block settings (transformer stage).
  block:
    attn_head_dims: 64
    attn_num_heads: 16
    attn_use_rot_embs: true
    attn_dropout: 0
    use_flash_attn: true
    patch_pos_emb_type: fixed
train:
  # NOTE(review): underscore digit separators are YAML 1.1 / PyYAML behavior —
  # confirm the config loader parses 30_000_000_000 as an integer.
  target_elements: 30_000_000_000
  target_elements_strategy: sequence
  batch_size: 4
  shuffle_train: false
  learning_rate: 0.001
  gradient_clipping: 1
  gradient_accumulate_every: 12
30 changes: 30 additions & 0 deletions config/pg19_30bb_360m_2d_ss.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Experiment config: PG-19, 30B target elements, 360M params, 2D state-space
# (two stages, both SSM blocks).
# NOTE(review): nesting reconstructed from context — indentation was lost in
# extraction; verify against the trainer's config schema.
io:
  name_model: 8k_30b_360m_2d_ss
  # NOTE(review): the "< set this!" markers are part of the scalar values and
  # must be replaced with real paths before this config is usable.
  output_dir: /data/output < set this!
  dataset_dir: /data/datasets/pg19 < set this!
  dataset_id: pg19
  num_models_to_save: 5
  validate_amount: 100
  log_train_loss_amount: 1000
params:
  num_tokens: 256
  pad_token_id: 0
  input_seq_len: 8192
  # Two stages: 1024-token outer patches, 8-token inner patches.
  seq_lens: [1024, 8]
  hidden_dims: [1024, 1024]
  num_layers: [28, 24]
  train_checkpoint_chunks: null
  # Mamba/SSM block settings, shared by both stages.
  block:
    d_state: 128
    d_conv: 4
    expand: 2
    headdim: 64
    patch_pos_emb_type: null
train:
  target_elements: 30_000_000_000
  target_elements_strategy: sequence
  batch_size: 6
  shuffle_train: false
  learning_rate: 0.001
  gradient_clipping: 1
  gradient_accumulate_every: 8
36 changes: 36 additions & 0 deletions config/pg19_30bb_360m_2d_st.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Experiment config: PG-19, 30B target elements, 360M params, 2D hybrid
# (stage 1: SSM block, stage 2: transformer block — one list entry per stage).
# NOTE(review): nesting reconstructed from context — indentation was lost in
# extraction; verify against the trainer's config schema.
io:
  name_model: 8k_30b_360m_2d_st
  # NOTE(review): the "< set this!" markers are part of the scalar values and
  # must be replaced with real paths before this config is usable.
  output_dir: /data/output < set this!
  dataset_dir: /data/datasets/pg19 < set this!
  dataset_id: pg19
  num_models_to_save: 5
  validate_amount: 100
  log_train_loss_amount: 1000
params:
  num_tokens: 256
  pad_token_id: 0
  input_seq_len: 8192
  seq_lens: [1024, 8]
  hidden_dims: [1024, 1024]
  num_layers: [25, 21]
  train_checkpoint_chunks: null
  # Per-stage block settings: SSM first, attention second.
  block:
    - d_state: 128
      d_conv: 4
      expand: 2
      headdim: 64
      patch_pos_emb_type: null
    - attn_head_dims: 64
      attn_num_heads: 16
      attn_use_rot_embs: true
      attn_dropout: 0
      use_flash_attn: true
      patch_pos_emb_type: fixed
train:
  target_elements: 30_000_000_000
  target_elements_strategy: sequence
  batch_size: 6
  shuffle_train: false
  learning_rate: 0.001
  gradient_clipping: 1
  gradient_accumulate_every: 8
5 changes: 5 additions & 0 deletions config/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
This folder contains example experiment configurations in YAML format that can be passed to the torchrun launch script via:

```sh
bash src/mblm/scripts/train_launch.sh <config.yaml>
```
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,27 @@ dev-dependencies = [
"pre-commit>=3.8.0",
"pytest-mock>=3.14.0",
"types-tabulate>=0.9.0.20240106",
"rouge-score>=0.1.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.sdist]
include = ["src/**/*", "mit.tmpl"]

[tool.mypy]
check_untyped_defs = true

[[tool.mypy.overrides]]
module = ["tqdm.*", "mambapy.*", "mamba_ssm.*", "MEGABYTE_pytorch.*"]
module = [
"tqdm.*",
"mambapy.*",
"mamba_ssm.*",
"MEGABYTE_pytorch.*",
"rouge_score.*",
]
ignore_missing_imports = true

[[tool.mypy.overrides]]
Expand Down
4 changes: 2 additions & 2 deletions src/mblm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__version__ = "0.0.1"


from mblm.model.config import MBLMModelConfig
from mblm.model.config import MBLMModelConfig, MBLMReturnType
from mblm.model.mblm import MBLM

__all__ = ["MBLM", "MBLMModelConfig"]
__all__ = ["MBLM", "MBLMModelConfig", "MBLMReturnType"]
Loading

0 comments on commit 73af7c9

Please sign in to comment.