diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9e9b1cb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,53 @@ +name: CI + +on: + pull_request: + branches: + - main + push: + branches: + - main +jobs: + check: + name: Lint and check types + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v4 + - name: 'Set up Python' + uses: actions/setup-python@v5 + with: + python-version-file: '.python-version' + - name: Lint + run: make lint + - name: Check formatting + run: uv run ruff format src tests --check + - name: Check types + run: make check_types + + test: + name: Test Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - '3.10' + - '3.11' + - '3.12' + steps: + - uses: actions/checkout@v4 + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Set up Python + run: uv python install + - name: Install the project + run: uv sync --all-extras --dev + - name: Run unit tests + run: make test_unit + - name: Run integration tests + run: make test_integration + - name: Run e2e tests + run: make test_e2e diff --git a/Makefile b/Makefile index ddb571d..5f06b17 100644 --- a/Makefile +++ b/Makefile @@ -47,18 +47,100 @@ install_mamba: uv pip install --no-build-isolation \ mamba-ssm==${MAMBA_VERSION} \ causal-conv1d==${CASUAL_CONV_VERSION} - + .PHONY: check_types check_types: - uv run mypy src + uv run mypy src tests .PHONY: lint lint: - uv run ruff check src + uv run ruff check src tests .PHONY: format format: - uv run ruff format src + uv run ruff format src tests + +.PHONY: test +test: test_unit test_integration test_e2e + +.PHONY: test_unit +test_unit: + uv run pytest tests/unit + +.PHONY: test_integration +test_integration: + uv run pytest tests/integration + +E2E_RUN_TORCH = OMP_NUM_THREADS=1 \ + uv run torchrun --nproc_per_node=2 \ + tests/e2e/trainer/run_trainer.py +E2E_RUN_VALIDATE = uv run tests/e2e/trainer/validate.py +E2E_TEST_ROOT = tests/e2e/trainer + +.PHONY: test_e2e +test_e2e: + @echo "Clearing test output dir $(E2E_TEST_ROOT)/outputs" + rm -rf $(E2E_TEST_ROOT)/outputs/* + $(MAKE) test_e2e_grad_acc + $(MAKE) test_e2e_trainer + +.PHONY: test_e2e_grad_acc +test_e2e_grad_acc: + TEST_ID=grad_acc_1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-1.yml + TEST_ID=grad_acc_2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-grad-acc-2.yml + $(E2E_RUN_VALIDATE) \ + --check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_2/loss.csv \ + --check-grad-acc-csv $(E2E_TEST_ROOT)/outputs/my-model_grad_acc_1/loss.csv + +.PHONY: test_e2e_trainer
test_e2e_trainer: + # ---------------- Chain 3 runs of 1 epoch each ---------------- + @echo "Running training for a single epoch" + TEST_ID=1 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1-epoch.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_1 + + @echo "Chaining run 2 to run 1" + TEST_ID=2 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_1/config.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_2 + + @echo "Chaining run 3 to run 2" + TEST_ID=3 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_2/config.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_3 + + @echo "Asserting on chained training runs 1, 2, 3" + $(E2E_RUN_VALIDATE) \ + --assert-equal-epochs \ + --check-chained-csv
$(E2E_TEST_ROOT)/outputs/my-model_1/loss.csv \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_2/loss.csv \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_3/loss.csv + + # ------------ Chain 2 runs of more than 1 epoch each ------------ + @echo "Running training for more than 1 epoch" + TEST_ID=4 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-1.5-epoch.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_4 + + @echo "Chaining run 5 to run 4" + TEST_ID=5 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_4/config.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_5 + + @echo "Asserting on chained training runs 4, 5" + $(E2E_RUN_VALIDATE) \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_4/loss.csv \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_5/loss.csv + + # ------------ Chain 2 runs of less than 1 epoch each ------------ + @echo "Running training for less than 1 epoch" + TEST_ID=6 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/sample-config-0.5-epoch.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_6 + + @echo "Chaining run 7 to run 6" + TEST_ID=7 $(E2E_RUN_TORCH) -c $(E2E_TEST_ROOT)/outputs/my-model_6/config.yml + $(E2E_RUN_VALIDATE) --check-output $(E2E_TEST_ROOT)/outputs/my-model_7 + + @echo "Asserting on chained training runs 6, 7" + $(E2E_RUN_VALIDATE) \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_6/loss.csv \ + --check-chained-csv $(E2E_TEST_ROOT)/outputs/my-model_7/loss.csv .PHONY: clear_cache clear_cache: diff --git a/config/clevr_7mi_360m_1d_s_pt1_ft.yaml b/config/clevr_7mi_360m_1d_s_pt1_ft.yaml new file mode 100644 index 0000000..a896ded --- /dev/null +++ b/config/clevr_7mi_360m_1d_s_pt1_ft.yaml @@ -0,0 +1,54 @@ +io: + name_model: 8k_7mi_360m_1d_s_pt1_ft + output_dir: /data/output < set this! + dataset_dir: /data/datasets/clevr < set this! + dataset_id: clevr + dataset_args: + target_mode: a + qiqa_loss_mask: [0.0, 0.0, 0.0, 1] + answer_categorical: true + resize_to_w: 62 + resize_to_h: 41 + crop_h_perc: 0.1 + crop_w_perc: 0.1 + eom_token_id: 129 + som_text_token_id: 130 + som_image_token_id: 131 + downsample_channels: null + shift_channels_start: null + num_models_to_save: 5 + validate_amount: 100 + log_train_loss_amount: 1000 + description: >- + Describe this! +params: + num_tokens: 256 + pad_token_id: 128 + input_seq_len: 8192 + seq_lens: [8192] + hidden_dims: [1024] + num_layers: [54] + train_checkpoint_chunks: null + block: + d_state: 128 + d_conv: 4 + expand: 2 + headdim: 64 + dropout: 0.1 + patch_pos_emb_type: null +train: + target_elements: 7_000_000 + target_elements_strategy: batch + batch_size: 6 + max_eval_steps: 1000 + shuffle_train: true + learning_rate: 0.0001 + gradient_clipping: 0.5 + gradient_accumulate_every: 48 +resume: + checkpoint_file: /path-to-checkpoint.pth + next_batch_index: 0 + next_epoch_index: 0 + migrate_embeddings: false + rename_modules: true + resumed_from: null diff --git a/config/pg19_30bb_360m_1d_t.yaml b/config/pg19_30bb_360m_1d_t.yaml new file mode 100644 index 0000000..fdb92d2 --- /dev/null +++ b/config/pg19_30bb_360m_1d_t.yaml @@ -0,0 +1,31 @@ +io: + name_model: 8k_30b_360m_1d_t + output_dir: /data/output < set this! + dataset_dir: /data/datasets/pg19 < set this! 
+ dataset_id: pg19 + num_models_to_save: 5 + validate_amount: 100 + log_train_loss_amount: 1000 +params: + num_tokens: 256 + pad_token_id: 0 + input_seq_len: 8192 + seq_lens: [8192] + hidden_dims: [1024] + num_layers: [42] + train_checkpoint_chunks: null + block: + attn_head_dims: 64 + attn_num_heads: 16 + attn_use_rot_embs: true + attn_dropout: 0 + use_flash_attn: true + patch_pos_emb_type: fixed +train: + target_elements: 30_000_000_000 + target_elements_strategy: sequence + batch_size: 4 + shuffle_train: false + learning_rate: 0.001 + gradient_clipping: 1 + gradient_accumulate_every: 12 diff --git a/config/pg19_30bb_360m_2d_ss.yaml b/config/pg19_30bb_360m_2d_ss.yaml new file mode 100644 index 0000000..075151f --- /dev/null +++ b/config/pg19_30bb_360m_2d_ss.yaml @@ -0,0 +1,30 @@ +io: + name_model: 8k_30b_360m_2d_ss + output_dir: /data/output < set this! + dataset_dir: /data/datasets/pg19 < set this! + dataset_id: pg19 + num_models_to_save: 5 + validate_amount: 100 + log_train_loss_amount: 1000 +params: + num_tokens: 256 + pad_token_id: 0 + input_seq_len: 8192 + seq_lens: [1024, 8] + hidden_dims: [1024, 1024] + num_layers: [28, 24] + train_checkpoint_chunks: null + block: + d_state: 128 + d_conv: 4 + expand: 2 + headdim: 64 + patch_pos_emb_type: null +train: + target_elements: 30_000_000_000 + target_elements_strategy: sequence + batch_size: 6 + shuffle_train: false + learning_rate: 0.001 + gradient_clipping: 1 + gradient_accumulate_every: 8 diff --git a/config/pg19_30bb_360m_2d_st.yaml b/config/pg19_30bb_360m_2d_st.yaml new file mode 100644 index 0000000..4fde520 --- /dev/null +++ b/config/pg19_30bb_360m_2d_st.yaml @@ -0,0 +1,36 @@ +io: + name_model: 8k_30b_360m_2d_st + output_dir: /data/output < set this! + dataset_dir: /data/datasets/pg19 < set this! 
+ dataset_id: pg19 + num_models_to_save: 5 + validate_amount: 100 + log_train_loss_amount: 1000 +params: + num_tokens: 256 + pad_token_id: 0 + input_seq_len: 8192 + seq_lens: [1024, 8] + hidden_dims: [1024, 1024] + num_layers: [25, 21] + train_checkpoint_chunks: null + block: + - d_state: 128 + d_conv: 4 + expand: 2 + headdim: 64 + patch_pos_emb_type: null + - attn_head_dims: 64 + attn_num_heads: 16 + attn_use_rot_embs: true + attn_dropout: 0 + use_flash_attn: true + patch_pos_emb_type: fixed +train: + target_elements: 30_000_000_000 + target_elements_strategy: sequence + batch_size: 6 + shuffle_train: false + learning_rate: 0.001 + gradient_clipping: 1 + gradient_accumulate_every: 8 diff --git a/config/readme.md b/config/readme.md new file mode 100644 index 0000000..dfae807 --- /dev/null +++ b/config/readme.md @@ -0,0 +1,5 @@ +This folder contains example experiment configurations in yaml format that can be passed to torchrun via: + +```sh +bash src/mblm/scripts/train_launch.sh +``` diff --git a/pyproject.toml b/pyproject.toml index c302cb2..736e47a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,17 +37,27 @@ dev-dependencies = [ "pre-commit>=3.8.0", "pytest-mock>=3.14.0", "types-tabulate>=0.9.0.20240106", + "rouge-score>=0.1.2", ] [build-system] requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.sdist] +include = ["src/**/*", "mit.tmpl"] + [tool.mypy] check_untyped_defs = true [[tool.mypy.overrides]] -module = ["tqdm.*", "mambapy.*", "mamba_ssm.*", "MEGABYTE_pytorch.*"] +module = [ + "tqdm.*", + "mambapy.*", + "mamba_ssm.*", + "MEGABYTE_pytorch.*", + "rouge_score.*", +] ignore_missing_imports = true [[tool.mypy.overrides]] diff --git a/src/mblm/__init__.py b/src/mblm/__init__.py index 771cf79..a470b91 100644 --- a/src/mblm/__init__.py +++ b/src/mblm/__init__.py @@ -10,7 +10,7 @@ __version__ = "0.0.1" -from mblm.model.config import MBLMModelConfig +from mblm.model.config import MBLMModelConfig, MBLMReturnType from mblm.model.mblm import MBLM -__all__ = ["MBLM", "MBLMModelConfig"] +__all__ = ["MBLM", "MBLMModelConfig", "MBLMReturnType"] diff --git a/src/mblm/analysis/aggregate.py b/src/mblm/analysis/aggregate.py new file mode 100644 index 0000000..25850f7 --- /dev/null +++ b/src/mblm/analysis/aggregate.py @@ -0,0 +1,292 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import math +from dataclasses import dataclass +from datetime import datetime +from functools import reduce +from pathlib import Path +from typing import Iterable, Literal, get_args + +import polars as pl +from pydantic import ValidationError + +from mblm.data.types import ModelMode +from mblm.model.block import BlockType +from mblm.scripts.train_mblm import TrainOutputConfig +from mblm.utils.io import load_yml + +ModelType = Literal["SSM", "Transformer", "Mixed"] + +# weirdly enough, we cannot use this enum in the schema definition directly and +# need to cast after the df has been created +_mode_enum = pl.Enum([ModelMode.TRAIN.value, ModelMode.VALID.value, ModelMode.TEST.value]) +_model_type_enum = pl.Enum(get_args(ModelType)) + +_gpu_df_default_schema = dict( + cum_batch=pl.Int32, + num_items=pl.Int16, + kind=pl.String, # renamed to mode + fw_time=pl.Float32, + bw_time=pl.Float32, + allocated=pl.Float32, + allocated_max=pl.Float32, + reserved=pl.Float32, + reserved_max=pl.Float32, + total=pl.Float32, +) +# we have added new fields since, check Aggregator +_loss_df_default_schema = dict( + timestamp=pl.Datetime, + elements_seen=pl.Int64, + kind=pl.String, # renamed to mode + epoch=pl.Int64, + batch=pl.Int64, + cum_batch=pl.Int64, + loss=pl.Float64, +) +_bpb_scale_factor = math.log2(math.e) + + +@dataclass +class Filter: + mode: ModelMode | None = None + num_stages_gte: int | None = None + num_stages_lte: int | None = None + + +class Aggregator: + def __init__(self, folders: Iterable[tuple[str, str]]): + self.configs, self.df_exp, self._df_train, df_gpu = self._read_experiments(folders) + self.exp_names = [folder_name[1] for folder_name in folders] + + # drop the very first train batch measurement, it includes the warmup times + self._df_gpu = df_gpu.filter( + ~((pl.col("mode") == "train") & (pl.col("cum_batch") == 0)), + ) + + def df_train(self, filter: Filter = Filter()) -> pl.DataFrame: + return self._filter_df(self._df_train, filter).pipe(self._rename_values_for_plot) + + def df_gpu(self, filter: Filter = Filter()) -> pl.DataFrame: + return self._filter_df(self._df_gpu, filter).pipe(self._rename_values_for_plot) + + def _filter_df(self, df_with_size: pl.DataFrame, filter: Filter) -> pl.DataFrame: + f = filter + return ( + df_with_size.lazy() + .filter((pl.col("mode").eq(f.mode)) if f.mode else True) + .filter((pl.col("stages").ge(f.num_stages_gte)) if f.num_stages_gte else True) + .filter((pl.col("stages").le(f.num_stages_lte)) if f.num_stages_lte else True) + .collect() + ) + + def _rename_values_for_plot(self, df: pl.DataFrame) -> pl.DataFrame: + return df.with_columns( + pl.col("kind").replace(["train", "valid", "test"], ["Training", "Validation", "Test"]), + pl.col("stages") + .cast(pl.String) + .replace(["1", "2", "3"], ["1D, 8k ctx", "2D, 100k ctx", "3D, 1m ctx"]), + ) + + def _df_with_common_props( + self, + df: pl.DataFrame, + config: TrainOutputConfig, + exp_name: str, + model_type: str, + ): + return df.with_columns( + mode=pl.col("kind").cast(_mode_enum), + name=pl.lit(exp_name), + model_type=pl.lit(model_type).cast(_model_type_enum), + stages=pl.lit(len(config.params.hidden_dims)).cast(pl.Int8), + seq_len=pl.lit(config.params.input_seq_len), + ) + + def _determine_model_type(self, config: TrainOutputConfig) 
-> ModelType: + block_types: list[BlockType] = [block.block_type for block in config.params.blocks()] + if all([block_type == "mamba2" for block_type in block_types]): + return "SSM" + if all([block_type == "transformer" for block_type in block_types]): + return "Transformer" + return "Mixed" + + def _read_experiments( + self, folders: Iterable[tuple[str, str]] + ) -> tuple[list[TrainOutputConfig], pl.DataFrame, pl.DataFrame, pl.DataFrame]: + dfs_exp: list[pl.DataFrame] = [] + dfs_train: list[pl.DataFrame] = [] + dfs_gpu: list[pl.DataFrame] = [] + configs: list[TrainOutputConfig] = [] + for folder, exp_name in folders: + path = Path(folder) + try: + config = load_yml(path / "config.yml", parse_to=TrainOutputConfig) + + grad_acc_every = config.train.gradient_accumulate_every + model_type = self._determine_model_type(config) + + # static df with experiment/model details + df_exp = pl.DataFrame().with_columns( + name=pl.lit(exp_name), + model_type=pl.lit(model_type).cast(_model_type_enum), + ctx_size_total=pl.lit(config.params.input_seq_len), + ctx_sizes=pl.lit(config.params.seq_lens), + elements_trained=pl.lit(config.train.target_elements), + params_m=pl.lit(config.summary.parameter_count).truediv(1e6).round(0), + num_layers=pl.lit(config.params.num_layers), + num_tokens=pl.lit(config.params.num_tokens), + checkpoints=pl.lit(config.params.train_checkpoint_chunks), + patch_pos_emb=pl.lit( + [b.patch_pos_emb_type or "none" for b in config.params.blocks()] + ), + lr=pl.lit(config.train.learning_rate), + batch_size=pl.lit(config.train.batch_size), + grad_step=pl.lit(grad_acc_every * config.train.batch_size), + grad_clip=pl.lit(config.train.gradient_clipping), + num_gpus=pl.lit(config.summary.num_workers), + dataset_args=pl.lit(config.io.dataset_args, allow_object=True), + comment=pl.lit(config.io.description), + training_finished=pl.lit(bool(config.summary.training_end)), + error=pl.lit(config.summary.error), + ) + + dfs_exp.append(df_exp) + + df_train = ( + self._try_read_csv(path / "loss.csv", _loss_df_default_schema, True) + .pipe( + self._df_with_common_props, + config=config, + exp_name=exp_name, + model_type=model_type, + ) + .with_columns( + pl.col("elements_seen") * config.summary.num_workers, # TODO check + bpb=pl.col("loss") * _bpb_scale_factor, + cum_batch=(pl.col("cum_batch") + 1), + ) + .with_columns(step=((pl.col("cum_batch")) // grad_acc_every).cast(pl.Int32)) + ) + dfs_train.append(df_train) + + df_gpu = ( + self._try_read_csv(path / "timemem.csv", _gpu_df_default_schema) + .pipe( + self._df_with_common_props, + config=config, + exp_name=exp_name, + model_type=model_type, + ) + .rename({"num_items": "batch_size"}) + ) + dfs_gpu.append(df_gpu) + except ValidationError as e: + print(folder) + raise e + + configs.append(config) + df_exp = reduce(lambda new_df, df: df.vstack(new_df), dfs_exp) + df_train = reduce(lambda new_df, df: df.vstack(new_df), dfs_train) + df_gpu = reduce(lambda new_df, df: df.vstack(new_df), dfs_gpu) + return configs, df_exp, df_train, df_gpu + + def _try_read_csv(self, csv_path: str | Path, schema: dict, is_loss_csv: bool = False): + try: + df = pl.read_csv( + csv_path, + schema_overrides=schema, + raise_if_empty=False, + null_values="nan", + ) + except Exception: + df = pl.DataFrame(schema=schema) + + # we added these three fields later in the project, hence, they don't exist + # for all existing csv files - interpolate + if is_loss_csv: + if "lr" not in df.columns: + df = df.with_columns( + lr=pl.lit(-1, pl.Float64), + avg_grad=pl.lit(-1, pl.Float64), + 
avg_grad_clipped=pl.lit(-1, pl.Float64), + ) + if "gpu_rank" not in df.columns: + df = df.with_columns(gpu_rank=pl.lit(0, pl.Int16)) + + df = df.with_columns( + pl.col("avg_grad", "avg_grad_clipped").cast(pl.Float64), + pl.col("gpu_rank").cast(pl.Int16), + ) + + return df + + def _convert_utc_timezones( + self, + df: pl.DataFrame, + columns: list[str], + to_timezone: str, + ) -> pl.DataFrame: + utc_fmt = "%Y-%m-%dT%H:%M:%S%.f" + exprs = [ + pl.col(col) + .str.to_datetime(format=utc_fmt, time_zone="UTC", time_unit="ms") + .dt.replace_time_zone(to_timezone) + for col in columns + ] + return df.with_columns(*exprs) + + def df_train_times(self) -> pl.DataFrame: + df = pl.DataFrame( + [ + [ + name, + cfg.summary.training_start, + cfg.summary.training_end or datetime.now().isoformat(), + bool(cfg.summary.training_end), + ] + for cfg, name in zip(self.configs, self.exp_names) + ], + orient="row", + schema=["name", "train_start", "train_end", "has_finished"], + ) + return ( + df.pipe( + self._convert_utc_timezones, + columns=["train_start", "train_end"], + to_timezone="Europe/Brussels", + ) + .with_columns( + train_dur=(pl.col("train_end") - pl.col("train_start")), + ) + .with_columns( + train_dur_h=pl.col("train_dur").dt.total_hours(), + train_dur_str=pl.col("train_dur").dt.total_days().cast(pl.String) + + "d " + + (pl.col("train_dur").dt.total_hours() % 24).cast(pl.String) + + "h " + + (pl.col("train_dur").dt.total_minutes() % 60).cast(pl.String) + + "m", + ) + .sort(pl.col("train_dur")) + ) diff --git a/src/mblm/analysis/metrics.py b/src/mblm/analysis/metrics.py new file mode 100644 index 0000000..9798692 --- /dev/null +++ b/src/mblm/analysis/metrics.py @@ -0,0 +1,58 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +from typing import Literal, Sequence, TypeVar + +from rouge_score import rouge_scorer, scoring + +from mblm.data.utils.bytes import Bytes + +T = TypeVar("T") + + +def token_accuracy(a: Sequence[T], b: Sequence[T], allow_var_lens: bool) -> float: + if len(a) != len(b) and not allow_var_lens: + raise ValueError("Sequences must have the same length") + if len(a) == 0 or len(b) == 0: + # opinionated edge case + return 0 + correct = 0 + for item_a, item_b in zip(a, b): + correct += int(item_a == item_b) + return correct / len(a) + + +rouge_scorers = dict( + rouge1=rouge_scorer.RougeScorer(["rouge1"]), rougeL=rouge_scorer.RougeScorer(["rougeL"]) +) + + +def rouge_score_from_bytes( + target: list[int], predicted: list[int], which: Literal["rouge1", "rougeL"] +) -> float: + try: + t = Bytes.byte_list_to_str(target) + p = Bytes.byte_list_to_str(predicted) + score: scoring.Score = rouge_scorers[which].score(t, p)[which] + return score.fmeasure + except Exception: + return -1 diff --git a/src/mblm/analysis/utils.py b/src/mblm/analysis/utils.py new file mode 100644 index 0000000..bc61c61 --- /dev/null +++ b/src/mblm/analysis/utils.py @@ -0,0 +1,77 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import re +from pathlib import Path +from typing import TypeAlias + +from mblm import MBLM, MBLMModelConfig +from mblm.data.utils import Bytes +from mblm.scripts.train_mblm import TrainOutputConfig +from mblm.utils.io import load_model_state, load_yml + +ExpCollection: TypeAlias = list[tuple[str, str]] + + +def extract_prompt_with_offset(text: str, txt_offset: int, prompt_len_bytes: int) -> str: + # because the text offset is computed in UTF-8 and the prompt/context length + # in bytes, convert between the two for slicing + prompt_as_tensor = Bytes.str_to_tensor(text[txt_offset:]) + return Bytes.tensor_to_str(prompt_as_tensor[:prompt_len_bytes]) + + +LINE_BREAK_TABS_RE = re.compile(r"[\r\n]+") + + +def strip_line_breaks(text: str) -> str: + return LINE_BREAK_TABS_RE.sub(" ", text) + + +def load_model( + model_id: str, + model_dir: Path, + device: str, +) -> tuple[MBLM, TrainOutputConfig]: + config_file = model_dir / (model_id + ".yml") + state_file = model_dir / (model_id + ".pth") + + config = load_yml(config_file, TrainOutputConfig) + model = MBLM( + MBLMModelConfig( + num_tokens=config.params.num_tokens, + hidden_dims=tuple(config.params.hidden_dims), + seq_lens=tuple(config.params.seq_lens), + pad_token_id=config.params.pad_token_id, + num_layers=tuple(config.params.num_layers), + train_checkpoint_chunks=config.params.train_checkpoint_chunks, + block=config.params.block, + ) + ).to(device) + + model, _ = load_model_state( + state_file, + model, + map_location=device, + # we've renamed this module + map_rename_modules=(("pos_embs", "patch_pos_embs"),), + ) + return model, config diff --git a/src/mblm/data/dataset/clevr.py b/src/mblm/data/dataset/clevr.py index 687a377..8eb8a5b 100644 --- a/src/mblm/data/dataset/clevr.py +++ b/src/mblm/data/dataset/clevr.py @@ -24,7 +24,7 @@ import os import random from pathlib import Path -from typing import Generator, Literal, TypedDict +from typing import Generator, Literal, TypedDict, overload import torch from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -43,10 +43,18 @@ ) +class ClevrFunction(TypedDict): + function: str + inputs: list[int] + value_inputs: list[str] + + class ClevrQuestion(TypedDict): question: str answer: str image_filename: str + question_family_index: int + program: list[ClevrFunction] class ClevrOptionalArgs(BaseModel): @@ -133,6 +141,25 @@ class Clevr(DistributedDataset[BatchWithLossMask]): "yes": "R", } + QUESTION_TYPES: dict[str, str] = { + "exist": "exists", + "count": "count", + # compare integer + "equal_integer": "compare_integer", + "less_than": "compare_integer", + "greater_than": "compare_integer", + # query attribute + "query_color": "query_attribute", + "query_material": "query_attribute", + "query_size": "query_attribute", + "query_shape": "query_attribute", + # compare attribute + "equal_size": "compare_attribute", + "equal_material": "compare_attribute", + "equal_shape": "compare_attribute", + "equal_color": "compare_attribute", + } + # common (transposed) image shape for all modes IMAGE_SHAPE_C_W_H = 3, 480, 320 MAX_QUESTION_LEN_BYTES = 205 # 203 for validation @@ -218,17 +245,14 @@ def process_and_flatten_img(self, image: ImagePipeline) -> torch.Tensor: image_tensor = image.to_tensor().flatten() return image_tensor - def 
get_sample_raw(self, from_idx: int) -> tuple[str, str, ImagePipeline]: + def get_sample_raw(self, from_idx: int) -> ClevrQuestion: """ - Get a raw sample as a (question, answer, image) tuple with no tokenization - or preprocessing applied. + Get a raw dataset entry with no tokenization or preprocessing applied. """ if self.mode == ModelMode.TEST: raise ValueError("Clevr dataset does not support testing!") - answer_str = self.entries[from_idx]["answer"] - question_str = self.entries[from_idx]["question"] - image_path = self.images_root / self.entries[from_idx]["image_filename"] - return question_str, answer_str, ImagePipeline(image_path, self.image_color_space) + return self.entries[from_idx] def iter_images( self, shuffle: bool = False, max_items: int | None = None @@ -246,13 +270,21 @@ def iter_images( for img_path in image_lst: yield ImagePipeline(self.images_root / img_path, self.image_color_space).to_tensor() + @overload def iter( - self, shuffle: bool = False, max_items: int | None = None - ) -> Generator[tuple[str, str, torch.Tensor], None, None]: + self, *, shuffle: bool = ..., max_items: int | None = ..., raw: Literal[True] + ) -> Generator[tuple[int, ClevrQuestion], None, None]: ... + @overload + def iter( + self, shuffle: bool = ..., max_items: int | None = ..., raw: Literal[False] = ... + ) -> Generator[tuple[int, tuple[str, str, ImagePipeline]], None, None]: ... + def iter( + self, shuffle: bool = False, max_items: int | None = None, raw: bool = False + ) -> Generator[tuple[int, tuple[str, str, ImagePipeline] | ClevrQuestion], None, None]: """ - Iterate over all the question, answer image tuples in Clevr. While - question/answer pairs unique, images may appear more than once. No - preprocessing is applied. + Iterate over all the (question, answer, image) tuples in Clevr (or raw + entries if specified). While question/answer pairs are unique, images may + appear more than once. No preprocessing is applied. """ entries_range = list(range(len(self.entries))) if shuffle: @@ -260,8 +292,12 @@ def iter( if max_items is not None: entries_range = entries_range[:max_items] for i in entries_range: - question, answer, image = self.get_sample_raw(i) - yield question, answer, image.to_tensor() + s = self.get_sample_raw(i) + if raw: + yield i, s + else: + img = ImagePipeline(self.images_root / s["image_filename"], self.image_color_space) + yield i, (s["question"], s["answer"], img) def get_sample_with_parts( self, from_idx: int @@ -279,8 +315,11 @@ def get_sample_with_parts( 3.
A tuple containing the original `q`, `i`, `a` parts that are concatenated in `q_i_q_a` and the size of the right padding applied """ - question_str, answer_str, image_pipeline = self.get_sample_raw(from_idx) - + s = self.get_sample_raw(from_idx) + question_str, answer_str = (s["question"], s["answer"]) + image_pipeline = ImagePipeline( + self.images_root / s["image_filename"], self.image_color_space + ) # question and image as tensors, not yet tokenized question = Bytes.str_to_tensor(question_str) image = self.process_and_flatten_img(image_pipeline) diff --git a/src/mblm/model/config.py b/src/mblm/model/config.py index f6c990a..dde96c0 100644 --- a/src/mblm/model/config.py +++ b/src/mblm/model/config.py @@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.""" +from enum import Enum, auto from itertools import repeat from typing import Sequence @@ -29,10 +30,16 @@ from mblm.model.transformer import TransformerBlockConfig +class MBLMReturnType(str, Enum): + LOGITS = auto() + LOSS = auto() + LOSS_LOGITS = auto() + + class MBLMModelConfig(BaseModel): """ General config for creating a MBLM model. For all iterables, - the order corresponds to global to most local stage from left to right. + the order corresponds to global to local stage from left to right. Params: num_tokens: The vocabulary size diff --git a/src/mblm/model/embeddings.py b/src/mblm/model/embeddings.py index e053970..19a8453 100644 --- a/src/mblm/model/embeddings.py +++ b/src/mblm/model/embeddings.py @@ -21,7 +21,7 @@ SOFTWARE.""" -MMB_TOKEN_EMB_MIGRATION: set[str] = { +MBLM_TOKEN_EMB_MIGRATION: set[str] = { "token_embs_rev.0.weight", "token_embs_rev.1.0.weight", "to_logits.weight", diff --git a/src/mblm/model/mamba.py b/src/mblm/model/mamba.py index a2b858d..c6a91ae 100644 --- a/src/mblm/model/mamba.py +++ b/src/mblm/model/mamba.py @@ -28,7 +28,7 @@ class MambaBlockConfig(StageBlock, BaseModel): """ - General config for creating a Mamba block inside MMB. + General config for creating a Mamba block inside MBLM. Uses roughly 3 * expand * d_model^2 parameters. 
Parameters in brackets [x] denote the notation used in Mambabyte diff --git a/src/mblm/model/mblm.py b/src/mblm/model/mblm.py index 63eb2b3..1cec06a 100644 --- a/src/mblm/model/mblm.py +++ b/src/mblm/model/mblm.py @@ -21,7 +21,7 @@ SOFTWARE.""" import math -from typing import cast +from typing import Literal, cast, overload import torch import torch.nn.functional as F # noqa: N812 @@ -31,14 +31,14 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from mblm.model.config import MBLMModelConfig +from mblm.model.config import MBLMModelConfig, MBLMReturnType from mblm.model.utils import ByteToUtf8Streamer, RoPE, gumbel_sample, top_k """ Wording: - A hierarchy consists of n stages - - Stage 1 corresponds to to the global model (Yu et al., 2023) - - Stage 2 to n corresponds to local models 2 to n + - Stage 1 to n-1 corresponds to global models + - Stage n corresponds to the local model Inline comment abbreviations: ------------------------------------------------------------------- Global notation @@ -46,16 +46,8 @@ - V Vocabulary size (256 for bytes) - B Batch size - L The input sequence length - - S_n The sequence length/number of patches at any stage n + - P_n The sequence length/number of patches at any stage n - D_n The model dimension at any stage n - ------------------------------------------------------------------- - Derived example notation - ------------------------------------------------------------------- - - D_1 Global model dimension - - S_1 Global model sequence length (maximum) - - S_1' Unpadded / actual global model sequence length: S_1' <= S_1 - - D_2 Dimension of first local model at stage 2 in the hierarchy - - S_2 Patch size of first local model at stage 2 in the hierarchy """ @@ -108,27 +100,27 @@ def __init__(self, cfg: MBLMModelConfig): ) ) - most_local_dim = self.model_dims[-1] # D_n + local_dim = self.model_dims[-1] # D_n # token embeddings are created in reverse order from local to global: # [(V, D_n), ..., (V, D_1)] self.token_embs_rev = nn.ModuleList( - [nn.Embedding(cfg.num_tokens, most_local_dim, padding_idx=self.pad_token_id)] + [nn.Embedding(cfg.num_tokens, local_dim, padding_idx=self.pad_token_id)] ) patch_size = 1 for dim_out, seq_len in zip( - # all except the most local model + # all except the local model reversed(self.model_dims[:-1]), # (D_n-1, ..., D_1) - reversed(self.seq_lens[1:]), # (S_2, ..., S_n) + reversed(self.seq_lens[1:]), # (P_2, ..., P_n) ): patch_size *= seq_len self.token_embs_rev.append( nn.Sequential( - nn.Embedding(cfg.num_tokens, most_local_dim, padding_idx=self.pad_token_id), + nn.Embedding(cfg.num_tokens, local_dim, padding_idx=self.pad_token_id), Rearrange("... r d -> ... 
(r d)"), - nn.LayerNorm(patch_size * most_local_dim), - nn.Linear(patch_size * most_local_dim, dim_out), + nn.LayerNorm(patch_size * local_dim), + nn.Linear(patch_size * local_dim, dim_out), nn.LayerNorm(dim_out), ) ) @@ -156,7 +148,7 @@ def __init__(self, cfg: MBLMModelConfig): self.to_next_stage_proj.append(proj) - self.to_logits = nn.Linear(most_local_dim, cfg.num_tokens) + self.to_logits = nn.Linear(local_dim, cfg.num_tokens) @torch.inference_mode() def generate( @@ -202,7 +194,7 @@ def generate( sequence = prime for _ in iterator: - logits = self.forward(sequence)[:, -1] + logits = self.forward(sequence, return_type=MBLMReturnType.LOGITS)[:, -1] logits = top_k(logits, thres=filter_thres) sampled = gumbel_sample(logits, dim=-1, temperature=temperature) sequence = torch.cat((sequence, rearrange(sampled, "b -> b 1")), dim=-1) @@ -236,19 +228,36 @@ def forward_empty(self, batch_size: int) -> torch.Tensor: return self.to_logits(tokens) + @overload def forward( self, input_ids: torch.Tensor, - return_loss: bool = False, + *, + return_type: Literal[MBLMReturnType.LOSS_LOGITS] = ..., loss_mask: torch.Tensor | None = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: ... + @overload + def forward( + self, + input_ids: torch.Tensor, + *, + return_type: Literal[MBLMReturnType.LOSS, MBLMReturnType.LOGITS] = ..., + loss_mask: torch.Tensor | None = None, + ) -> torch.Tensor: ... + + def forward( + self, + input_ids: torch.Tensor, + *, + return_type: MBLMReturnType = MBLMReturnType.LOSS, + loss_mask: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ - A single forward pass + A single forward pass. Args: input_ids: The token ids as torch.LongTensor in shape (B, L) - return_loss: If `True`, return the loss, else, return predictions. Predictions - will have the same shape (B, L) as `input_ids`. `False` by default + return_type: What to return - the loss, the logits or both. loss_mask: An optional masking tensor that enables interpolation between self-supervised and supervised learning. It determines which tokens in the prediction should contribute to the loss with what weight and should have @@ -275,49 +284,49 @@ def forward( flat_seq_len = input_ids.shape[-1] # if the input is given as (B, L), reshape and distribute it among the - # local hierarchy sequence lengths, filling up the most local dimensions - # first. padding is applied so that the output shape is (B, S_1', S_2, - # ..., S_n). for the largest possible L == prod(seq_lens), S_1' = S_1. - # in all other cases, S_1' < S_1, meaning no padding is applied to the - # global sequence S_1. + # hierarchy sequence lengths, filling up the inner dimensions + # first. padding is applied so that the output shape is (B, P_1', P_2, + # ..., P_n). for the largest possible L == prod(seq_lens), P_1' = P_1. + # in all other cases, P_1' < P_1, meaning no padding is applied to the + # global sequence P_1. 
# - # here's two examples for a model model (S_1, S_2, S_3) = (5, 4, 3): + # here are two examples for a model with (P_1, P_2, P_3) = (5, 4, 3): # - # input: (B, L) = (1, 13), output: (B, S_1', S_2, S_3) = (1, 2, 4, 3) - # input: (B, L) = (1, 60), output: (B, S_1, S_2, S_3) = (1, 5, 4, 3) + # input: (B, L) = (1, 13), output: (B, P_1', P_2, P_3) = (1, 2, 4, 3) + # input: (B, L) = (1, 60), output: (B, P_1, P_2, P_3) = (1, 5, 4, 3) # - # in the 2nd example, L = prod(seq_lens) = 5 * 4 * 3 = 60 = S_1 = S_1' + # in the 2nd example, L = prod(seq_lens) = 5 * 4 * 3 = 60 = P_1 = P_1' if flattened_dims: - # pad/fill up all local sequence lengths - local_seq_lens = self.seq_lens[1:] - multiple_of = math.prod(local_seq_lens) + # pad/fill up all inner sequence lengths (all except most global) + inner_seq_lens = self.seq_lens[1:] + multiple_of = math.prod(inner_seq_lens) # use the complement of modulo - the difference to the next multiple # of multiple_of - to infer the right padding length padding = -flat_seq_len % multiple_of input_ids = F.pad(input_ids, (0, padding), value=self.pad_token_id) - # reshape and infer the S_1' dimension - input_ids = input_ids.reshape(batch_size, -1, *local_seq_lens) + # reshape and infer the P_1' dimension + input_ids = input_ids.reshape(batch_size, -1, *inner_seq_lens) - # make sure the above condition holds, i.e., S_1' <= S_1 - _S_1_prime, _S_1 = input_ids.shape[1], self.seq_lens[0] # noqa: N806 + # make sure the above condition holds, i.e., P_1' <= P_1 + _P_1_prime, _P_1 = input_ids.shape[1], self.seq_lens[0] # noqa: N806 fixed_global_patch_encoding = isinstance(self.patch_pos_embs[0], nn.Embedding) if fixed_global_patch_encoding: - assert _S_1_prime <= _S_1, ( + assert _P_1_prime <= _P_1, ( f"Because you are using a fixed global patch embedding, " - f"the input sequence length ({_S_1_prime}) " - f"must be less than the first tuple element of seq_lens ({_S_1})" + f"the input sequence length ({_P_1_prime}) " + f"must be at most the first tuple element of seq_lens ({_P_1})" ) token_embs_at_stages = [torch.empty(0) for _ in range(self.num_stages)] # at this stage, we're working with nested ids - hence, embed the bytes - # for each stage in reverse order, starting from the most local and - # ending at the global model. at each stage, add positional embeddings - # and rerrange the input shape to match the dimension of the previous + # for each stage in reverse order, starting from the local and ending at + # the most global model. at each stage, add positional embeddings and + # rearrange the input shape to match the dimension of the previous + # (local) stage.
with three stages: # - # [0]: (B, S_1', D_1) - # [1]: (B, S_1', S_2, D_2) - # [2]: (B, S_1', S_2, S_3, D_3) + # [0]: (B, P_1', D_1) + # [1]: (B, P_1', P_2, D_2) + # [2]: (B, P_1', P_2, P_3, D_3) for stage_idx, pos_emb, token_emb in zip( range(self.num_stages - 1, -1, -1), reversed(self.patch_pos_embs), @@ -333,8 +342,8 @@ def forward( stage_token_embs = stage_token_embs + positions elif isinstance(pos_emb, RoPE): batch, *seq_lens, hidden_dim = stage_token_embs.shape - # RoPE expects as input [batch, seq_len, num_heads, head_dim] -> pack to artificial - # batch dim and add an empty head dimension + # RoPE expects as input [batch, seq_len, num_heads, head_dim] -> + # pack to artificial batch dim and add an empty head dimension stage_token_embs = stage_token_embs.view(-1, seq_lens[-1], 1, hidden_dim) stage_token_embs = pos_emb.forward(stage_token_embs) # reshape back to input @@ -365,9 +374,9 @@ def forward( # the first dimension ("*"), after packing, corresponds to the # number of patches in the current stage (which can be considered a - # batch for the stage because they are processed in parallel). in a - # sense, the sequence length of the global model becomes the batch - # size of the first local model, and so on + # batch for the stage because they are processed in parallel). the + # patch size of model n-1 becomes the batch size of model n, and so + # on stage_tokens, ps = pack([stage_emb_tokens], "* s d") patch_batch = stage_tokens.shape[0] # "*" in the pack operation above @@ -385,12 +394,12 @@ def forward( value=0, ) stage_tokens = stage_tokens + prev_stage_tokens_repr - # stage_tokens is now [B, S_n, D_n]. for the first stage, B is the + # stage_tokens is now [B, P_n, D_n]. for the first stage, B is the # actual batch size whereas for the other stages, the batch size # corresponds to the sequence length of the previous hierarchy # stage - # skip checkpointing for the first (=global) stage + # skip checkpointing for the first global stage if checkpoint_chunk is None: attended = model.forward(stage_tokens) @@ -422,12 +431,12 @@ def forward( # restore the initial shape attended = unpack(attended, ps, "* s d")[0] # project for next stage in the hierarchy, dropping the last patch: - # from (..., S_n, D_n) to (..., S_n, S_n+1, D_n+1) + # from (..., P_n, D_n) to (..., P_n, P_n+1, D_n+1) prev_stage_tokens_repr = proj(attended[..., :-1, :]) - logits = self.to_logits.forward(attended) # (B, S_1', S_2, ..., 1 + S_n, V) + logits = self.to_logits.forward(attended) # (B, P_1', P_2, ..., 1 + P_n, V) - if not return_loss: + if return_type == MBLMReturnType.LOGITS: if flattened_dims: # drop the start tokens and combine inner dimensions into one logits = rearrange(logits[..., 1:, :], "b ... v -> b (...) v") @@ -459,7 +468,7 @@ def forward( # same shape as targets with 0 everywhere where the token equals the pad # token. 
this assumes the same pad token is used for intra-batch padding # (ensured by the datasets/dataloaders) as well as patch-padding - # (ensured by bootstrapping mmb with the right pad token id) + # (ensured by bootstrapping MBLM with the right pad token id) loss_tensor: torch.Tensor = F.cross_entropy( preds, # (B, V, L) targets, # (B, L) @@ -485,5 +494,8 @@ def forward( # special case when the loss is zero across target elements - should # theoretically never happen print("Edge case detected, loss is nan") - return torch.zeros_like(loss) - return loss + loss = torch.zeros_like(loss) + + if return_type == MBLMReturnType.LOSS: + return loss + return loss, logits diff --git a/src/mblm/model/transformer.py b/src/mblm/model/transformer.py index 8fc1ceb..a238208 100644 --- a/src/mblm/model/transformer.py +++ b/src/mblm/model/transformer.py @@ -31,7 +31,7 @@ class TransformerBlockConfig(StageBlock, BaseModel): """ - General config for creating a Transformer block inside MMB. + General config for creating a Transformer block inside MBLM. """ attn_head_dims: int diff --git a/src/mblm/scripts/clevr_generation.py b/src/mblm/scripts/clevr_generation.py new file mode 100644 index 0000000..6aa703d --- /dev/null +++ b/src/mblm/scripts/clevr_generation.py @@ -0,0 +1,185 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + + +from datetime import datetime +from pathlib import Path + +import torch +from pydantic import BaseModel +from tqdm import tqdm + +from mblm import MBLM, MBLMReturnType +from mblm.analysis.utils import load_model +from mblm.data.dataset.clevr import Clevr, ClevrOptionalArgs +from mblm.data.types import ModelMode +from mblm.utils.seed import seed_everything + +DEVICE = "cuda" + +# clevr has no labelled test set +DATASET_MODE = ModelMode.VALID + + +class ClevrModelGeneration(BaseModel): + id_model: str + sample_idx: int + question: str + question_type: str + answer_gen: list[int] + answer_truth: list[int] + ce: float + timestamp: str + + +def sample_clevr_by_question_type(clevr: Clevr, items_per_question: int) -> list[tuple[str, int]]: + """ + Sample clevr indices by question type and return a list of (question_type, + index) tuples. 
+ + From the paper: "We categorize questions by question type, defined by the + outermost function in the question's program" + """ + seed_everything(8) + counts = {q: list[int]() for q in Clevr.QUESTION_TYPES.keys()} + for sample_i, sample in clevr.iter(shuffle=True, raw=True): + q_type = sample["program"][-1]["function"] + if len(q_list := counts[q_type]) < items_per_question: + q_list.append(sample_i) + sample_idxs_flat = [(q, i) for (q, q_idxs) in counts.items() for i in q_idxs] + assert len(sample_idxs_flat) == items_per_question * len(Clevr.QUESTION_TYPES) + return sample_idxs_flat + + +@torch.inference_mode() +@torch.autocast(device_type=DEVICE) +def sample_generation( + model: MBLM, + model_id: str, + clevr: Clevr, + sample_idx: int, + question_type: str, + output_file: Path, +) -> None: + _, _, (question, image, answer, _) = clevr.get_sample_with_parts(sample_idx) + + # reconstruct the prompt + prompt_qiq = torch.concat([question, image, question]).long().to(DEVICE) + + max_tokens_to_generate = len(answer) + + generated_qiqa = model.generate( + prompt_qiq, + temperature=1, + num_tokens_to_generate=max_tokens_to_generate, + enable_progress=False, + ) + # feed the generated bytes back into the model (with additional + # batch dim) to get the loss associated with this generation + loss = model.forward(generated_qiqa.unsqueeze(0), return_type=MBLMReturnType.LOSS) + + # the generated bytes also contain the prompt, strip it + generated_answer = generated_qiqa[len(prompt_qiq) :] + raw_sample = clevr.get_sample_raw(sample_idx) + output = ClevrModelGeneration( + id_model=model_id, + sample_idx=sample_idx, + # not strictly needed - save the answer as a string for convenient + # processing + question=raw_sample["question"], + # From the paper: "We categorize questions by question type, defined by + # the outermost function in the question’s program" + question_type=question_type, + answer_gen=generated_answer.tolist(), + answer_truth=answer.tolist(), + ce=float(loss.item()), + timestamp=str(datetime.now()), + ) + + with output_file.open("a") as f: + f.write(output.model_dump_json() + "\n") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--clevr-dir", + dest="clevr_dir", + type=Path, + ) + parser.add_argument( + "--out-file", + dest="output_file", + type=Path, + ) + parser.add_argument( + "--model-dir", + dest="model_dir", + type=Path, + ) + parser.add_argument( + "-m", + dest="model_id", + action="append", + type=str, + ) + parser.add_argument( + "-n", + dest="num_samples_per_question", + type=int, + ) + + args = parser.parse_args() + output_file: Path = args.output_file + model_dir: Path = args.model_dir + num_samples_per_question: int = args.num_samples_per_question + model_ids: list[str] = args.model_id + clevr_dir: Path = args.clevr_dir + + for model_id in model_ids: + print(f"Model {model_id}") + model, model_config = load_model(model_id=model_id, model_dir=model_dir, device=DEVICE) + model.eval() + clevr_dataset_config = ClevrOptionalArgs.model_validate(model_config.io.dataset_args) + clevr = Clevr( + clevr_dir, + mode=DATASET_MODE, + seq_len=model_config.params.input_seq_len, + pad_token_id=model_config.params.pad_token_id, + optional_args=clevr_dataset_config, + num_workers=1, + worker_id=0, + ) + sample_idxs = sample_clevr_by_question_type(clevr, num_samples_per_question) + + for question_type, sample_idx in (pbar := tqdm(sample_idxs)): + pbar.set_description(f"Sample {sample_idx}") + sample_generation( + 
model=model, + model_id=model_id, + clevr=clevr, + sample_idx=sample_idx, + question_type=question_type, + output_file=output_file, + ) diff --git a/src/mblm/scripts/clevr_generation.sh b/src/mblm/scripts/clevr_generation.sh new file mode 100644 index 0000000..dc73bb9 --- /dev/null +++ b/src/mblm/scripts/clevr_generation.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +uv run src/mblm/scripts/clevr_generation.py \ + --model-dir ../../disk/inference/mmb_clevr \ + --clevr-dir ../../disk/data/clevr \ + --out-file misc/clevr_qa_pt1.jsonl \ + -m 8k_7mi_360m_1d_s_pt1_ft \ + -m 8k_7mi_360m_1d_s_pt1_nft \ + -m 8k_7mi_360m_1d_t_pt1_ft \ + -m 8k_7mi_360m_1d_t_pt1_nft \ + -n 300 +# uv run src/mblm/scripts/clevr_generation.py \ +# --model-dir ../../disk/inference/mmb_clevr \ +# --clevr-dir ../../disk/data/clevr \ +# --out-file misc/clevr_qa_pt2.jsonl \ +# -m 8k_2mi_360m_1d_s_pt2_disc \ +# -m 8k_2mi_360m_1d_s_pt2_nodisc \ +# -m 8k_2mi_360m_1d_s_pt2_jpeg \ +# -m 8k_2mi_360m_1d_t_pt2_disc \ +# -m 8k_2mi_360m_1d_t_pt2_nodisc \ +# -m 8k_2mi_360m_1d_t_pt2_jpeg \ +# -n 300 \ No newline at end of file diff --git a/src/mblm/scripts/cuda_check.py b/src/mblm/scripts/cuda_check.py new file mode 100644 index 0000000..e67bf5d --- /dev/null +++ b/src/mblm/scripts/cuda_check.py @@ -0,0 +1,28 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import pprint + +from mblm.utils.cuda import cuda_properties + +if __name__ == "__main__": + pprint.pprint(cuda_properties()) diff --git a/src/mblm/scripts/model_overview.py b/src/mblm/scripts/model_overview.py new file mode 100644 index 0000000..e0b6f0e --- /dev/null +++ b/src/mblm/scripts/model_overview.py @@ -0,0 +1,104 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + + +import re +from itertools import chain +from pathlib import Path +from typing import Sequence + +from tabulate import tabulate + +from mblm import MBLM +from mblm.scripts.train_mblm import TrainEntryConfig +from mblm.utils.io import load_yml +from mblm.utils.misc import count_params + + +def resolve_configs( + file_or_dirs: Sequence[Path], re_filter: str | None = None +) -> list[tuple[Path, TrainEntryConfig]]: + if len(file_or_dirs) == 1 and (config_path := file_or_dirs[0]).is_file(): + return [(config_path, load_yml(config_path, parse_to=TrainEntryConfig))] + yaml_files = chain.from_iterable(p.rglob("*.yaml") for p in file_or_dirs if p.is_dir()) + yml_files = chain.from_iterable(p.rglob("*.yml") for p in file_or_dirs if p.is_dir()) + + pattern = re.compile(re_filter or ".*") + config_files = filter(lambda p: pattern.match(p.name), chain(yaml_files, yml_files)) + return [(path, load_yml(path, parse_to=TrainEntryConfig)) for path in config_files] + + +def print_model_sizes( + file_or_dirs: list[Path], re_filter: str | None = None, count_model_params: bool = False +): + table_data: list[list] = [] + header = ["Config", "Inp. seq len", "Seq. lens", "# Layers"] + if count_model_params: + header += ["Params"] + + configs = sorted( + resolve_configs(file_or_dirs, re_filter), + key=lambda pc: (pc[1].params.input_seq_len, pc[1].params.seq_lens), + ) + for path, conf in configs: + if conf.io.dataset_id != "pg19": + continue + inp_len = conf.params.input_seq_len + seg_lens = conf.params.seq_lens + layers = str(conf.params.num_layers) + row = [path.name, inp_len, seg_lens, layers] + if count_model_params: + model = MBLM(conf.params) + params = f"{count_params(model)[0] / 1e6:.0f}m" + row += [params] + + table_data.append(row) + + print(tabulate(table_data, headers=header, tablefmt="simple_grid")) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "file_or_dir", + type=Path, + nargs="+", + help="Path to a folder with experiments specified as YAML files or path to a single YAML file", + ) + parser.add_argument( + "-p", + dest="count_model_params", + type=bool, + action=argparse.BooleanOptionalAction, + help="Whether or not to count model parameters", + ) + parser.add_argument( + "-f", + dest="regex_filter", + type=str, + help="Regex filter expression on yaml files", + ) + + args = parser.parse_args() + print_model_sizes(args.file_or_dir, args.regex_filter, args.count_model_params) diff --git a/src/mblm/scripts/pg19_generation.py b/src/mblm/scripts/pg19_generation.py new file mode 100644 index 0000000..405908b --- /dev/null +++ b/src/mblm/scripts/pg19_generation.py @@ -0,0 +1,236 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the 
Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import re +from datetime import datetime +from pathlib import Path +from typing import cast + +import torch +from pydantic import BaseModel + +from mblm import MBLM, MBLMReturnType +from mblm.analysis.utils import load_model +from mblm.data.dataset.pg19 import PG19 +from mblm.data.types import ModelMode +from mblm.data.utils import Bytes +from mblm.utils.io import NDJSONWriter +from mblm.utils.seed import seed_everything + +DEVICE = "cuda" +DATASET_MODE = ModelMode.VALID +# dot, whitespace, uppercase, 1+ lowercase, whitespace, 1+ lowercase +START_OF_SENTENCE_RE = re.compile(r"\.\s[A-Z][a-z]{1,}\s[a-z]{1,}") + + +class PG19ModelGeneration(BaseModel): + id_model: str + book_id: str + book_txt_offset: int + ctx_len: int + generated: list[int] + truth: list[int] + ce: float + generation_time: float + timestamp: str + + +def seek_to_start_of_sentence(text: str) -> int | None: + start_of_sentence = START_OF_SENTENCE_RE.search(text) + if start_of_sentence: + return start_of_sentence.start() + 2 # ignore dot and whitespace + return None + + +def get_pg19_samples( + pg19: PG19, max_num_samples: int, ctx_len: int +) -> list[tuple[str, int, torch.Tensor]]: + samples: list[tuple[str, int, torch.Tensor]] = [] + seed_everything(8) + for book_id, book_as_str in pg19.iter_books(shuffle=False): + offset = seek_to_start_of_sentence(book_as_str) + if offset is None: + continue + book_as_tensor = Bytes.str_to_tensor(book_as_str[offset:]) + if len(book_as_tensor) >= ctx_len: + samples.append((book_id, offset, book_as_tensor)) + if len(samples) == max_num_samples: + break + + return samples + + +@torch.inference_mode() +@torch.autocast(device_type=DEVICE) +def sample_generation( + output_file: Path, + model: MBLM, + model_id: str, + max_num_samples: int, + ctx_len: int, + generation_len: int, +) -> None: + writer = NDJSONWriter[PG19ModelGeneration](output_file) + + samples = get_pg19_samples(pg19, max_num_samples, ctx_len) + num_available_samples = len(samples) + if num_available_samples < max_num_samples: + print( + f"Warning, context length {ctx_len}: Only {num_available_samples} " + f"available ({max_num_samples} requested)" + ) + else: + print(f"{num_available_samples} samples for context length {ctx_len}") + + for i, (book_id, book_txt_offset, book_as_tensor) in enumerate(samples): + print(f"\t[{i}] Book {book_id}, start time: {datetime.now()}") + + # write a temporary line to indicate we've tried generation + writer.write_line( + PG19ModelGeneration( + id_model=model_id, + book_id=book_id, + book_txt_offset=book_txt_offset, + ctx_len=ctx_len, + generated=[], + truth=[], + ce=-1, + generation_time=-1, + timestamp=str(datetime.now()), + ) + ) + + prompt_len = ctx_len - generation_len + prompt = book_as_tensor[:prompt_len].long().to(DEVICE) + ground_truth = book_as_tensor[prompt_len : prompt_len +
generation_len] + + torch.cuda.empty_cache() + + # warmup + _ = model.forward_empty(1) + + start_time = datetime.now() + gen_bytes = model.generate( + prompt, + temperature=0.9, + filter_thres=0.9, + num_tokens_to_generate=generation_len, + ) + generation_time = datetime.now() - start_time + # feed the generated bytes back into the model (with additional + # batch dim) to get the loss associated with this generation + loss_tensor = model.forward(gen_bytes.unsqueeze(0), return_type=MBLMReturnType.LOSS) + loss_tensor = cast(torch.Tensor, loss_tensor) + loss = float(loss_tensor.item()) + generated = gen_bytes[prompt_len:].tolist() + assert len(generated) == generation_len + + # remove the last line and write the actual result + writer.remove_last_line() + writer.write_line( + PG19ModelGeneration( + id_model=model_id, + book_id=book_id, + book_txt_offset=book_txt_offset, + ctx_len=ctx_len, + generated=generated, + truth=ground_truth.tolist(), + ce=loss, + generation_time=generation_time.total_seconds(), + timestamp=str(datetime.now()), + ) + ) + print(f"\t[{i}] Book {book_id}, generation time: {generation_time.total_seconds()}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--pg19-dir", + dest="pg19_dir", + type=Path, + ) + parser.add_argument( + "--out-file", + dest="output_file", + type=Path, + ) + parser.add_argument( + "--model-dir", + dest="model_dir", + type=Path, + ) + parser.add_argument( + "--model-id", + dest="model_id", + type=str, + ) + parser.add_argument( + "--max-num-samples", + dest="max_num_samples", + type=int, + ) + parser.add_argument( + "--generation-len", + dest="generation_len", + type=int, + ) + parser.add_argument( + "--ctx-len", + dest="ctx_len", + action="extend", + nargs="+", + type=int, + ) + + args = parser.parse_args() + model_id: str = args.model_id + output_file: Path = args.output_file + model_dir: Path = args.model_dir + max_num_samples: int = args.max_num_samples + generation_len: int = args.generation_len + ctx_lens: list[int] = args.ctx_len + + pg19 = PG19( + args.pg19_dir, + mode=DATASET_MODE, + seq_len=500_000, # does not strictly matter + num_workers=1, + worker_id=0, + display_load_progress=False, + ) + + print(f"Model: {model_id}, context lengths: {ctx_lens}") + + for ctx_len in ctx_lens: + model, model_config = load_model(model_id=model_id, model_dir=model_dir, device=DEVICE) + model.eval() + sample_generation( + output_file=output_file, + model=model, + model_id=model_id, + max_num_samples=max_num_samples, + ctx_len=ctx_len, + generation_len=generation_len, + ) diff --git a/src/mblm/scripts/pg19_generation.sh b/src/mblm/scripts/pg19_generation.sh new file mode 100644 index 0000000..68946e6 --- /dev/null +++ b/src/mblm/scripts/pg19_generation.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# we check the following model ids: +# 8k_30b_360m_1d_ssm +# 8k_30b_360m_1d_t_nopos +# 100k_200b_360m_2d_ss +# 100k_200b_360m_2d_st + +# for a quick debugging run, you can execute: +# RUN_PREFIX=test CTX_LEN="8192 32768" NUM_SAMPLES=1 GEN_LEN=3 bash src/multiscale_mambabyte/scripts/pg19_generation.sh 8k_30b_360m_1d_ssm + +# check if at least one model ID is provided +if [ "$#" -lt 1 ]; then + echo "Usage: $0 [ ...]" + exit 1 +fi + +# read model IDs from command line arguments +MODEL_IDS=("$@") + +# defaults that can be overwritten via env variables +RUN_PREFIX=${RUN_PREFIX:-""} +NUM_SAMPLES=${NUM_SAMPLES:-10} +GEN_LEN=${GEN_LEN:-512} +# these need to be sorted in ascending order - as soon as we run oom for a 
ctx +# len, we assume the larger ctx lens are also oom. calculated with: +# np.concat( +# [ +# np.array([8192]) * np.arange(1, 5, step=1) ** 2, +# np.array([8192]) * np.arange(5, 12, step=2) ** 2, +# ] +# ) +# +CTX_LEN=${CTX_LEN:-"8192 32768 73728 131072 204800 401408 663552 991232"} + +echo "--------------------------" +echo "Model ids: ${MODEL_IDS[@]}" +echo "Run prefix: ${RUN_PREFIX}" +echo "Num samples: ${NUM_SAMPLES}" +echo "Ctx lens: ${CTX_LEN}" +echo "Generation len samples: ${GEN_LEN}" +echo "--------------------------" + +if [[ -n "$RUN_PREFIX" ]]; then + # add underscore if run prefix is set + RUN_PREFIX="${RUN_PREFIX}_" +fi + +for model_id in "${MODEL_IDS[@]}"; do + uv run \ + src/mblm/scripts/pg19_generation.sh \ + --model-dir ../../disk/inference/mmb_pg19 \ + --pg19-dir ../../disk/data/pg19 \ + --out-file misc/gen_pg19/${RUN_PREFIX}${model_id}.jsonl \ + --model-id $model_id \ + --max-num-samples ${NUM_SAMPLES} \ + --generation-len ${GEN_LEN} \ + --ctx-len ${CTX_LEN} + + if [ $? -ne 0 ]; then + echo "$model_id ran out of memory, continuing with next model" + continue + fi +done diff --git a/src/mblm/scripts/train_launch.sh b/src/mblm/scripts/train_launch.sh new file mode 100644 index 0000000..65d6cde --- /dev/null +++ b/src/mblm/scripts/train_launch.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +if [ -z "$1" ]; then + echo "Error: No config path provided" + echo "Usage: $0 " + exit 1 +fi + +echo "Launching MBLM training from config file:" +echo $1 + +export JOB_ID=$(date +%s) +export DISPLAY_PROGRESS=1 + +# add the --standalone flag for single-node +OMP_NUM_THREADS=1 uv run torchrun \ + --standalone \ + --nproc_per_node=gpu \ + src/mblm/scripts/train_mblm.py \ + -c $1 \ No newline at end of file diff --git a/src/mblm/scripts/train_mblm.py b/src/mblm/scripts/train_mblm.py new file mode 100644 index 0000000..56970c5 --- /dev/null +++ b/src/mblm/scripts/train_mblm.py @@ -0,0 +1,220 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import math +import os +from pathlib import Path +from typing import Any, Iterator, Literal + +import torch +from torch.distributed.elastic.multiprocessing.errors import record +from torch.optim import Adam, Optimizer # type: ignore +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR + +from mblm import MBLM, MBLMModelConfig, MBLMReturnType +from mblm.data.dataset.clevr import Clevr, ClevrOptionalArgs +from mblm.data.dataset.pg19 import PG19 +from mblm.data.datasets import BatchWithLossMask +from mblm.data.types import ModelMode +from mblm.model.embeddings import MBLM_TOKEN_EMB_MIGRATION +from mblm.trainer.config import ( + CoreIoConfig, + CoreModelParams, + CoreTrainConfig, + GenericEntryConfig, + GenericOutputConfig, +) +from mblm.trainer.core import CoreTrainer +from mblm.utils.distributed import process_group +from mblm.utils.io import load_yml +from mblm.utils.logging import create_logger, shutdown_log_handlers +from mblm.utils.misc import count_params + + +class IoConfig(CoreIoConfig): + """ + Custom io settings on top of the core/required parameters + """ + + dataset_dir: str + dataset_id: Literal["pg19", "clevr"] + dataset_args: dict[str, Any] | None = None + description: str | None = None + + +class ModelParams(MBLMModelConfig, CoreModelParams): + """ + Combine the params required by the MBLM model and the trainer. + """ + + pass + + +class TrainOutputConfig(GenericOutputConfig[ModelParams, CoreTrainConfig, IoConfig]): + """ + A class that can be used directly to parse any output generated from + training with the MBLM model. 
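+
+    A minimal usage sketch (editor's illustration; the output path below is
+    hypothetical):
+
+        from mblm.utils.io import load_yml
+
+        out = load_yml("outputs/my-model_1/config.yml", parse_to=TrainOutputConfig)
+        print(out.summary.parameter_count, out.resume.checkpoint_file)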
+ """ + + pass + + +class TrainEntryConfig(GenericEntryConfig[ModelParams, CoreTrainConfig, IoConfig]): + def import_dataset(self, mode: ModelMode, worker_id: int, num_workers: int): + if self.io.dataset_id == "clevr": + # cannot pass None to model_validate + optional_args = ClevrOptionalArgs.model_validate(self.io.dataset_args or dict()) + + return Clevr( + data_dir=self.io.dataset_dir, + mode=mode, + pad_token_id=self.params.pad_token_id, + seq_len=self.params.input_seq_len, + worker_id=worker_id, + num_workers=num_workers, + optional_args=optional_args, + ) + + return PG19( + data_dir=self.io.dataset_dir, + mode=mode, + seq_len=self.params.input_seq_len, + worker_id=worker_id, + num_workers=num_workers, + ) + + +class MegabyteTrainer(CoreTrainer[MBLM, BatchWithLossMask, ModelParams, CoreTrainConfig, IoConfig]): + def init_model(self): + return MBLM( + MBLMModelConfig( + # number of tokens + num_tokens=self.config.params.num_tokens, + # transformer model dimension (global, local) + hidden_dims=tuple(self.config.params.hidden_dims), + # sequence length (global, local) + seq_lens=tuple(self.config.params.seq_lens), + pad_token_id=self.config.params.pad_token_id, + num_layers=tuple(self.config.params.num_layers), + train_checkpoint_chunks=self.config.params.train_checkpoint_chunks, + block=self.config.params.block, + ) + ) + + def model_forward(self, model, batch, device) -> torch.Tensor: + inputs, loss_mask = batch + inputs = inputs.to(device) + loss_mask = loss_mask.to(device) + loss: torch.Tensor = model.forward( + inputs, return_type=MBLMReturnType.LOSS, loss_mask=loss_mask + ) + return loss + + def configure_optimizer(self, parameters: Iterator[torch.nn.Parameter]) -> Optimizer: + return Adam( + parameters, + lr=self.config.train.learning_rate, + betas=(0.9, 0.95), + ) + + def configure_scheduler(self, optimizer, local_gradient_steps) -> LRScheduler: + warmup_steps = math.floor(local_gradient_steps * self.config.train.warmup_steps_perc) + linear = LinearLR( + optimizer, + total_iters=warmup_steps, + start_factor=0.1, + end_factor=1, + ) + cosine_iters = local_gradient_steps - warmup_steps + cosine = CosineAnnealingLR(optimizer, T_max=cosine_iters) + return SequentialLR( + optimizer, + [linear, cosine], + milestones=[warmup_steps], + ) + + def configure_run_id(self) -> str: + return os.getenv("JOB_ID") or super().configure_run_id() + + def configure_count_parameters(self, model): + return count_params(model) + + def migrate_embeddings_if_enabled(self): + # older versions - the pg19 pretrained models - of mblm may have been trained + # without modality tokens - provide this map to migrate the embeddings. + # enabled via yaml config + return MBLM_TOKEN_EMB_MIGRATION + + def rename_modules_if_enabled(self): + # we have renamed pos_embs to patch_pos_embs in newer versions. 
enabled + # via yaml config + return (("pos_embs", "patch_pos_embs"),) + + +@record +def main(config: TrainEntryConfig) -> None: + log = create_logger(__name__, log_dir=config.io.output_dir) + + try: + with process_group(backend="nccl") as run_vars: + train_dataset = config.import_dataset( + mode=ModelMode.TRAIN, + worker_id=run_vars.local_rank, + num_workers=run_vars.world_size, + ) + valid_dataset = config.import_dataset( + mode=ModelMode.VALID, + worker_id=0, + num_workers=1, + ) + + trainer = MegabyteTrainer(config, run_vars=run_vars) + best_model = trainer.train(train_dataset, valid_dataset) + + supports_test_mode = train_dataset.supports_test_mode() + if best_model and supports_test_mode: + test_dataset = config.import_dataset( + mode=ModelMode.TEST, + worker_id=0, + num_workers=1, + ) + trainer.test(test_dataset, best_model) + except Exception as error: + log.fatal(error, exc_info=True) + shutdown_log_handlers() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + dest="config_path", + required=True, + type=Path, + help="Path to the experiment yaml config file", + ) + + args = parser.parse_args() + train_cfg = load_yml(args.config_path, parse_to=TrainEntryConfig) + main(train_cfg) diff --git a/src/mblm/trainer/__init__.py b/src/mblm/trainer/__init__.py new file mode 100644 index 0000000..cd0fc30 --- /dev/null +++ b/src/mblm/trainer/__init__.py @@ -0,0 +1,21 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" diff --git a/src/mblm/trainer/config.py b/src/mblm/trainer/config.py new file mode 100644 index 0000000..6bf7a13 --- /dev/null +++ b/src/mblm/trainer/config.py @@ -0,0 +1,204 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +""" +This module defines the core attributes that need to be specified for an +experiment. More granular training runs may subclass the core classes and add +additional attributes. +""" + +from typing import Generic, Literal, NamedTuple, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +""" +The return type of any DistributedDataset that is used with the Core trainer - +the tuple indicates , and target can be None for +self-supervised learning +""" + + +class CoreModelParams(BaseModel): + """ + The core model parameters that all models must specify when training with + the `CoreTrainer`. + """ + + input_seq_len: int = Field( + description="The input sequence length. This can be regarded as the context size when training for language. In other scenarios like linear regression, this might be a single data point, hence 1" + ) + + +class CoreTrainConfig(BaseModel): + """ + The core training parameters needed for a training run with the + `CoreTrainer`. + """ + + target_elements: int = Field( + description="The desired number of data points to train on. If `None`, defaults to using data points from the training set once, resulting in a single epoch. Note that this is a *lower bound* - due to the sequence length and batch sizes, in effect, we train on more than this target. Use this when you want to train on a fixed subset of data, e.g., a number of bytes" + ) + target_elements_strategy: Literal["batch", "sequence"] = Field( + description="The strategy to count elements in a batch, used to determine when `target_elements` is achieved. 'batch' means batch size. 'sequence' will count each element in the batch as contributing to the `target_elements`. E.g., when a batch has `n` sequences of length `L`, then the 'sequence' strategy will count `n` * `L` target elements per batch" + ) + warmup_steps_perc: float = Field( + default=0.1, + description="A float in the range [0, 1] to determine how many of the total gradient steps should be used for a warmup. The rest of the steps follows cosine annealing", + ) + batch_size: int = Field(description="The batch size") + + shuffle_train: bool = Field( + default=False, + description="Shuffle the training data in the data loader. Enable this only if you are sure you're not chaining runs", + ) + shuffle_eval: bool = Field( + default=False, description="Shuffle the validation data in the data loader" + ) + max_eval_steps: int | None = Field( + default=None, description="Maximum number of evaluation iterations" + ) + learning_rate: float = Field(description="The learning rate") + gradient_clipping: float | None = Field( + default=None, description="A value for gradient clipping" + ) + gradient_accumulate_every: int = Field( + description="After how many batches to accumulate the gradient" + ) + + +class ResumeConfig(BaseModel): + """ + The resume parameters needed for to resume training from a checkpoint and a + specific epoch and batch with `CoreTrainer`. In case training starts for the + first time, this class is not needed. 
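+
+    An illustrative construction (editor's example; the checkpoint path is a
+    placeholder):
+
+        resume = ResumeConfig(
+            checkpoint_file="/data/output/my-model_1/my-model_top1.pth",
+            next_epoch_index=0,
+            next_batch_index=1200,
+        )
+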
For a completed training run, it is + automatically populated so that it can be used in a future training run. + """ + + checkpoint_file: str = Field(description="Path to model state checkpoint") + next_epoch_index: int = Field(description="The epoch to resume from") + next_batch_index: int = Field(description="The global batch counter to resume from") + migrate_embeddings: bool = Field( + default=False, description="Migrate an existing smaller number of embeddings" + ) + rename_modules: bool = Field( + default=False, description="Rename modules from existing models to new names" + ) + resumed_from: str | None = Field( + default=None, + description="If training has been resumed from a previous checkpoint/experiment, points to the config of that experiment to easily trace chained training runs", + ) + + +class CoreIoConfig(BaseModel): + """ + The core input/output parameters needed for a training run with the + `CoreTrainer`. + """ + + name_model: str = Field(description="Model name for saving checkpoints") + output_dir: str = Field( + description="The output directory for all artefacts (will be created automatically based on `model_name` and a unique postfix)" + ) + num_models_to_save: int = Field( + description="Max number of best performing models to store. If smaller than `validate_amount`, stores only `validate_amount` models" + ) + validate_amount: int = Field( + description="How often (in total) to run the validation set and reserve a model candidate. **Must be >= than `num_models_to_save`**" + ) + log_train_loss_amount: int = Field(description="How often (in total) to log training loss") + enabled_loss_log_for_gpus: list[int] = Field( + default=[0], description="The rank of the GPUs that should write to the CSV loss file" + ) + + +TModelParams = TypeVar("TModelParams", covariant=True, bound=CoreModelParams) +TTrainConfig = TypeVar("TTrainConfig", covariant=True, bound=CoreTrainConfig) +TIoConfig = TypeVar("TIoConfig", covariant=True, bound=CoreIoConfig) + + +class SummaryStats(BaseModel): + """ + Summary stats for a training run + """ + + training_start: str + training_end: str + num_workers: int + cuda_devices: list[str] + parameter_count: int + error: str | None + + +class GenericEntryConfig(BaseModel, Generic[TModelParams, TTrainConfig, TIoConfig]): + """ + A generic entry config for training with the trainer that can be made more + specific by subclassing. + """ + + model_config = ConfigDict(strict=True) + + params: TModelParams + train: TTrainConfig + io: TIoConfig + resume: ResumeConfig | None = None + + +class GenericOutputConfig(BaseModel, Generic[TModelParams, TTrainConfig, TIoConfig]): + """ + A generic output config for a training run with the trainer that can be made + more specific by subclassing. 
+ """ + + model_config = ConfigDict(strict=True) + + params: TModelParams + train: TTrainConfig + io: TIoConfig + resume: ResumeConfig + summary: SummaryStats + + +class CSVLossEntry(NamedTuple): + gpu_rank: int + timestamp: str + elements_seen: int + kind: str + epoch: int + batch: int + cum_batch: int + loss: float + lr: float + avg_grad: float + avg_grad_clipped: float + + +class CSVTimeAndMemSnapshotEntry(NamedTuple): + cum_batch: int + num_items: int + kind: str + fw_time: float + bw_time: float | None + allocated: float + allocated_max: float + reserved: float + reserved_max: float + total: float diff --git a/src/mblm/trainer/core.py b/src/mblm/trainer/core.py new file mode 100644 index 0000000..c90c28f --- /dev/null +++ b/src/mblm/trainer/core.py @@ -0,0 +1,1002 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +import logging +import math +import pprint +import sys +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from time import time +from typing import Any, Generic, Iterable, Iterator, Literal, TypeVar, cast + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer # type: ignore +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm + +from mblm.data.datasets import DistributedDataset +from mblm.data.types import ModelMode +from mblm.trainer.config import ( + CoreIoConfig, + CSVLossEntry, + CSVTimeAndMemSnapshotEntry, + GenericEntryConfig, + GenericOutputConfig, + ResumeConfig, + SummaryStats, + TIoConfig, + TModelParams, + TTrainConfig, +) +from mblm.trainer.iter import epoch_cycler +from mblm.utils.cuda import cuda_memory_snapshot, cuda_properties +from mblm.utils.distributed import ElasticRunVars +from mblm.utils.io import CSVWriter, StateDict, dump_yml, load_model_state, save_model_state +from mblm.utils.logging import create_logger +from mblm.utils.misc import count_params +from mblm.utils.retry import retry +from mblm.utils.top_n import TopN + +TModel = TypeVar("TModel", bound=torch.nn.Module) +TBatch = TypeVar("TBatch") + + +@dataclass +class CoreTrainerOptions: + config_file_name: str = "config.yml" + loss_file_name: str = "loss.csv" + timemem_file_name: str = "timemem.csv" + max_train_restarts: int = 0 + skip_validation: bool = False + display_progress: bool = sys.stdout.isatty() + 
train_prog_min_interval_seconds: int = 1 + valid_prog_min_interval_seconds: int = 1 + track_first_fw_bw_exec_times: int | None = 30 # for 30 first passes, track fw/bw time + amp_dtype: torch.dtype = torch.half # may use bfloat16 + + +class CoreTrainer(ABC, Generic[TModel, TBatch, TModelParams, TTrainConfig, TIoConfig]): + """ + An abstract core trainer that provides a set of utility methods for + training, evaluating and testing. It is held generic to enforce type-safety + when implementing the abstract methods. + + - All methods that may or may not be implemented are `@classmethod` + + - Methods that _must_ be implemented are concerned with creating the + model, which is the job of the instantiator. They are abstract and + type-checkers complain if they are not implemented correctly + + - Methods that _may_ be implemented, i.e., overwritten, have the prefix + `with_`. For example, for the optimizer, an `Adam` optimizer with sensible + defaults is provided but it can be overwritten + + """ + + # public config + config: GenericEntryConfig[TModelParams, TTrainConfig, TIoConfig] + + # overridable options with sensible defaults + options: CoreTrainerOptions + + # private var + _local_rank: int + _world_size: int + _is_main_worker: bool + _device: str + _device_type: Literal["cuda", "cpu"] + _is_cuda: bool + + _model_dist: DistributedDataParallel + + # misc - created internally + _output_dir: Path + _running_resume_conf: ResumeConfig + _running_summary_stats: SummaryStats + _top_n_models: TopN[StateDict] + _csv_loss_writer: CSVWriter[CSVLossEntry] + _csv_timemem_writer: CSVWriter[CSVTimeAndMemSnapshotEntry] + _log: logging.Logger + + def __init__( + self, + config: GenericEntryConfig[TModelParams, TTrainConfig, TIoConfig], + run_vars: ElasticRunVars, + options: CoreTrainerOptions | None = None, + ): + self.config = config + self.options = options or CoreTrainerOptions() + self._world_size = run_vars.world_size + self._local_rank = run_vars.local_rank + # used for sending tensors/models to a device + self._device = f"cuda:{self._local_rank}" if run_vars.is_cuda else "cpu" + # used for mixed-precision + self._device_type = "cuda" if run_vars.is_cuda else "cpu" + self._is_cuda = not self._device == "cpu" + self._is_main_worker = run_vars.local_rank == 0 + + self._output_dir = self._create_output_dir(config.io) + + self._log = self.configure_logger(self._output_dir, self._is_main_worker) + self._top_n_models = TopN( + config.io.num_models_to_save, + deep_copy=True, # module state_dicts are references + ) + + # the ranks of the gpus that should write to the csv loss file + gpu_rank_csv_loss = set(self.config.io.enabled_loss_log_for_gpus) + self._csv_loss_writer = CSVWriter( + self._output_dir, + self.options.loss_file_name, + noop=self._local_rank not in gpu_rank_csv_loss, + ) + self._csv_timemem_writer = CSVWriter( + self._output_dir, self.options.timemem_file_name, noop=not self._is_main_worker + ) + + assert config.io.validate_amount > 0, "Validate amount must be strictly positive" + assert config.io.num_models_to_save > 0, "Must save at least 1 model" + + if config.io.validate_amount < config.io.num_models_to_save: + self._log.warning( + f"Validate amount ({ config.io.validate_amount}) \ + is less than number of models to save ({ config.io.num_models_to_save}).\ + Saving only { config.io.validate_amount} models" + ) + + model = self.init_model().to(self._device) + if config.resume: + self._log.info("Initiating model loading from checkpoint") + map_extend_embeddings = ( + 
self.migrate_embeddings_if_enabled() if config.resume.migrate_embeddings else None + ) + map_rename_modules = ( + self.rename_modules_if_enabled() if config.resume.rename_modules else None + ) + model, model_loss = load_model_state( + config.resume.checkpoint_file, + model, + map_location=self._device, + map_extend_embeddings=map_extend_embeddings, + map_rename_modules=map_rename_modules, + on_success=self._log.debug, + ) + # save the restored model to the top n to make sure we keep the best + # model from the previous run should we only make everything worse + # during this training + self._top_n_models.add((model_loss, model.state_dict())) + self._log.info(f"Loaded model with loss {model_loss:.4f} from checkpoint") + else: + self._log.info("Creating new model") + + self._model_dist = self._init_distributed_model(model) + + # initialize the running resume/summary configs that are updated on the fly + self._running_resume_conf = ResumeConfig( + checkpoint_file="", + next_batch_index=-1, + next_epoch_index=-1, + resumed_from=config.resume.checkpoint_file if config.resume else None, + ) + cuda_info = cuda_properties() + main_model_params, submodule_params = self.configure_count_parameters(model) + + self._running_summary_stats = SummaryStats( + parameter_count=main_model_params, + num_workers=run_vars.world_size, + cuda_devices=cuda_info.cuda_devices, + training_start="", # temporary, updated when training starts + training_end="", # temporary, updated when training ends + error=None, # temporary, updated when training fails + ) + self._dump_output_config() + self._log.info("Trainer initialized successfully") + self._log.info(f"Model parameters: {main_model_params}, ({submodule_params})") + self._log.info(f"Configuration: {pprint.pformat((config), indent=4)}") + self._log.info(f"CUDA: {pprint.pformat(cuda_info)}") + + """ Abstract methods that must be implemented """ + + @abstractmethod + def init_model(self) -> TModel: + """ + Initialize a model of the specified type `TModel`. + """ + ... + + @abstractmethod + def model_forward( + self, + model: TModel, + batch: TBatch, + device: str, + ) -> torch.Tensor: + """ + A single forward pass of the model. Both the model and batch are + generic, their types are inferred according to the type instantiation + defined when subclassing `CoreTrainer`. + + Args: + model (TModel): The model (already on the device) + batch (TBatch): One batch (MUST be put to device) + device (str): `cuda:n` for the `n`-th GPU or `cpu` + + Returns: + torch.Tensor: A Tensor with a single element that is the loss + of the forward pass + + **Example**:: + + # here, batch is a tuple of data, target + @classmethod + def model_forward(cls, model, batch, device): + x, y = batch + output = model.forward(x.to(device)) + loss_function = torch.nn.MSELoss() + loss: torch.Tensor = loss_function(output, y) + return loss + + + # in other scenarios, batch might be a single Tensor + @classmethod + def model_forward(cls, model, batch, device): + batch = batch.to(device).long() + loss: torch.Tensor = model.forward(batch, return_loss=True) + return loss + """ + + ... + + @abstractmethod + def configure_optimizer(self, parameters: Iterator[torch.nn.Parameter]) -> Optimizer: + """ + Configure an optimizer + """ + ... + + """ Default methods that can be overwritten """ + + def configure_scheduler(self, optimizer: Optimizer, local_gradient_steps: int) -> LRScheduler: + """ + Configure a LR scheduler. 
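+
+        Subclasses may return any LRScheduler; a minimal override sketch
+        (editor's illustration, not the project default):
+
+            def configure_scheduler(self, optimizer, local_gradient_steps):
+                return torch.optim.lr_scheduler.ConstantLR(
+                    optimizer, factor=1.0, total_iters=local_gradient_steps
+                )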
+ + Args: + optimizer (Optimizer): The optimizer + local_gradient_steps (int): The total number of gradient steps for this GPU. + """ + return torch.optim.lr_scheduler.PolynomialLR( + optimizer, + total_iters=local_gradient_steps, + power=1.0, + ) + + def configure_logger(self, output_dir: Path, is_main_worker: bool) -> logging.Logger: + """ + Customize the logger + """ + return create_logger( + name="train", + log_dir=output_dir, + # all non-main workers are noop loggers + noop=not is_main_worker, + ) + + def configure_count_parameters(self, model: TModel) -> tuple[int, dict[str, int]]: + """ + Determine how to count parameters for this model + """ + return count_params(model) + + def configure_run_id(self) -> str: + """ + Set a unique identifier for this experiment. Used as postfix for the + output directory. + """ + return f"{time():.0f}" + + def migrate_embeddings_if_enabled(self) -> set[str] | None: + """ + When resuming training from a model, a smaller number of embeddings can + be migrated to a larger number. This can be enabled via the resume + config. + """ + return None + + def rename_modules_if_enabled(self) -> Iterable[tuple[str, str]] | None: # noqa: ARG003 + """ + When resuming training from a model where modules are named differently, + provide a map in the form (source_prefix, target_prefix) to override + module names. + """ + return None + + """ Utility functions """ + + def _init_distributed_model(self, base_model: TModel) -> DistributedDataParallel: + """ + Create a distributed version of the model. + """ + + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(base_model) + # for multi-device modules and CPU modules, device_ids must be None + device_ids = [self._local_rank] if self._is_cuda else None + return DistributedDataParallel(model, device_ids=device_ids) + + def _unpack_distributed_model(self, module: TModel | DistributedDataParallel) -> TModel: + if isinstance(module, DistributedDataParallel): + return module.module + return module + + def _create_output_dir(self, io_config: CoreIoConfig) -> Path: + """ + Create a unique output directory (only on main worker). + """ + run_id = self.configure_run_id() + output_dir = Path(io_config.output_dir) / f"{io_config.name_model}_{run_id}" + + # only the main worker should do i/o + if self._is_main_worker: + output_dir.mkdir(parents=True, exist_ok=True) + + return output_dir + + def _dump_output_config(self): + """ + Dump all config files to disk (only on main worker). 
+ """ + if not self._is_main_worker: + return + # copy the config over into the output format + output_config = GenericOutputConfig( + io=self.config.io, + params=self.config.params, + train=self.config.train, + resume=self._running_resume_conf, + summary=self._running_summary_stats, + ) + dump_yml(self._output_dir / self.options.config_file_name, output_config) + + def _write_csv_loss( + self, + kind: ModelMode, + loss: float, + epoch: int, + batch: int, + cum_batch: int, + elements_seen: int, + lr: float, + avg_grad: float, + avg_grad_clipped: float, + ) -> None: + # no need to check for main worker - the writer has been initialized + # before so that only the main worker performs io + row = CSVLossEntry( + gpu_rank=self._local_rank, + timestamp=str(datetime.now()), + kind=kind.value, + elements_seen=elements_seen, + epoch=epoch, + batch=batch, + cum_batch=cum_batch, + loss=loss, + lr=lr, + avg_grad=avg_grad, + avg_grad_clipped=avg_grad_clipped, + ) + self._csv_loss_writer.write_row(row) + + def _write_csv_timemem( + self, cum_batch: int, num_items, kind: ModelMode, fw_time: float, bw_time: float | None + ): + mem_snapshot = cuda_memory_snapshot(self._device) + row = CSVTimeAndMemSnapshotEntry( + kind=kind.value, + num_items=num_items, + cum_batch=cum_batch, + fw_time=fw_time, + bw_time=bw_time, + allocated=mem_snapshot.allocated, + allocated_max=mem_snapshot.allocated_max, + reserved=mem_snapshot.reserved, + reserved_max=mem_snapshot.reserved_max, + total=mem_snapshot.total, + ) + self._csv_timemem_writer.write_row(row) + + def _save_best_models(self) -> tuple[int, int, Path]: + num_written = 0 + num_overwritten = 0 + best_checkpoint = Path() + if not self._is_main_worker: + return num_written, num_overwritten, best_checkpoint + + # save final n best models - best models are iterated first + for idx, (loss, model_state) in enumerate(self._top_n_models): + did_overwrite, checkpoint_path = save_model_state( + self._output_dir, + f"{self.config.io.name_model}_top{idx+1}.pth", + model=model_state, + loss=loss, + ) + num_written += 1 + if idx == 0: + best_checkpoint = checkpoint_path + if did_overwrite: + num_overwritten += 1 + return num_written, num_overwritten, best_checkpoint + + def _log_cuda_memory_snapshot(self, cumulative_batch_idx: int | None) -> None: + if self._is_cuda: + snapshot = cuda_memory_snapshot(self._device) + prefix = f"[{cumulative_batch_idx}] " if cumulative_batch_idx else "" + self._log.debug(f"{prefix}CUDA memory: {snapshot}") + + def _calc_logging_points( + self, + total_batch_iters: int, + start_batch_idx: int, + ) -> tuple[set[int], set[int]]: + """ + Calculate the indices for logging the training loss and validate based + on the total amount of batch iterations and offset start batch index + (when resuming training) + """ + log_train_loss_amount = self.config.io.log_train_loss_amount + if total_batch_iters < log_train_loss_amount: + self._log.warning( + f"Less batch iterations ({total_batch_iters}) " + f"than number of train loss log points ({log_train_loss_amount}). " + f"Clipping train loss log points to {total_batch_iters}" + ) + log_train_loss_amount = total_batch_iters + + validate_amount = self.config.io.validate_amount + if total_batch_iters < validate_amount: + self._log.warning( + f"Less batch iterations ({total_batch_iters}) " + f"than number of validation runs ({validate_amount}). 
" + f"Clipping validation points to {total_batch_iters}" + ) + validate_amount = total_batch_iters + log_train_loss_idxs = set( + torch.linspace( + start_batch_idx, + start_batch_idx + total_batch_iters - 1, + log_train_loss_amount, + ) + .long() + .tolist() + ) + run_valid_interval_idxs = set( + torch.linspace( + start_batch_idx, + start_batch_idx + total_batch_iters - 1, + validate_amount, + ) + .long() + .tolist() + ) + + return log_train_loss_idxs, run_valid_interval_idxs + + def _save_training_state(self, batch_i: int, epoch: int): + if not self._is_main_worker: + return + num_written, num_overwritten, best_model_path = self._save_best_models() + self._log.debug( + f"Saved {num_written} best model(s) " f"(overwrote {num_overwritten})", + ) + + # mutate running config in place, then save back + self._running_resume_conf.next_batch_index = batch_i + self._running_resume_conf.next_epoch_index = epoch + self._running_resume_conf.checkpoint_file = str(best_model_path) + + self._dump_output_config() + self._log.debug(f"Saved training state at epoch {epoch}, batch {batch_i}") + + def avg_gradient_value(self) -> float: + gradients = [ + p.grad.mean().item() for p in self._model_dist.parameters() if p.grad is not None + ] + return sum(gradients) / len(gradients) + + """ Training and evaluation """ + + def _evaluate( + self, + model: torch.nn.Module, + loader: DataLoader[TBatch], + items_seen_so_far: int, + cumulative_batch_idx: int, + ) -> float: + """ + Evaluate any model on any dataset. + """ + model.eval() + loss = 0.0 + time_taken = 0.0 + target_iters = ( + min(self.config.train.max_eval_steps, len(loader)) + if self.config.train.max_eval_steps + else len(loader) + ) + for it, batch in enumerate( + tqdm( + loader, + total=target_iters, + desc="Evaluating", + leave=False, + disable=not self.options.display_progress, + mininterval=self.options.valid_prog_min_interval_seconds, + ) + ): + if it == target_iters: + break + with torch.autocast( + device_type=self._device_type, + dtype=self.options.amp_dtype, + ): + with torch.inference_mode(): + start_eval = time() + loss_tensor = self.model_forward( + cast(TModel, model), + batch=batch, + device=self._device, + ) + eval_time = time() - start_eval + loss += float(loss_tensor.item()) + # for the eval dataloader, we don't drop the last batch, + # hence, the last batch might have a lower batch size. 
+ # therefore, count manually to report accurate times per + # element + time_taken += eval_time + + if self.options.track_first_fw_bw_exec_times: + self._write_csv_timemem( + cum_batch=cumulative_batch_idx, + kind=ModelMode.VALID, + fw_time=time_taken, + bw_time=None, + num_items=items_seen_so_far, + ) + return loss / target_iters + + def train( + self, + train_dataset: DistributedDataset[TBatch], + valid_dataset: DistributedDataset[TBatch], + ) -> TModel | None: + self._running_summary_stats.training_start = datetime.now().isoformat() + + def on_error(error: Exception, retries_left: int): + self._log.fatal( + f"Training failed, {retries_left}/{self.options.max_train_restarts} retries left" + ) + self._log.fatal(error, exc_info=True) + self._running_summary_stats.error = str(error) + if self.options.display_progress: + print(error) + + train_with_retry = retry(self.options.max_train_restarts, on_error=on_error)(self._train) + best_model = train_with_retry(train_dataset, valid_dataset) + + self._running_summary_stats.training_end = datetime.now().isoformat() + self._log_cuda_memory_snapshot(-99) + self._dump_output_config() + return best_model + + def _get_dataloader( + self, + dataset: DistributedDataset[TBatch], + data_loader_kwargs: dict[str, Any], + **additional_data_loader_kwargs: dict[str, Any], + ) -> DataLoader: + """Generic data loader instantiation. + + Args: + dataset: a distributed dataset object. + data_loader_kwargs: additional arguments for the data loader. + + Returns: + a data loader. + """ + return DataLoader(dataset, **{**data_loader_kwargs, **additional_data_loader_kwargs}) + + def get_train_dataloader( + self, dataset: DistributedDataset[TBatch], **additional_data_loader_kwargs: dict[str, Any] + ) -> DataLoader: + """Train data loader instantiation. + + Args: + dataset: a distributed dataset object. + data_loader_kwargs: additional arguments for the data loader. + + Returns: + the train data loader. + """ + return self._get_dataloader( + dataset=dataset, + data_loader_kwargs=dict( + batch_size=self.config.train.batch_size, + pin_memory=True, + shuffle=self.config.train.shuffle_train, # False by default + # drop the last batch so all batches have the same num of elements + drop_last=True, + # no need for a distributed sampler, the dataset is already distributed + sampler=None, + ), + **additional_data_loader_kwargs, + ) + + def get_valid_dataloader( + self, dataset: DistributedDataset[TBatch], **additional_data_loader_kwargs: dict[str, Any] + ) -> DataLoader: + """Validation data loader instantiation. + + Args: + dataset: a distributed dataset object. + data_loader_kwargs: additional arguments for the data loader. + + Returns: + the validation data loader. + """ + return self._get_dataloader( + dataset=dataset, + data_loader_kwargs=dict( + pin_memory=True, + shuffle=self.config.train.shuffle_eval, # False by default + drop_last=False, + batch_size=self.config.train.batch_size, + ), + **additional_data_loader_kwargs, + ) + + def get_test_dataloader( + self, dataset: DistributedDataset[TBatch], **additional_data_loader_kwargs: dict[str, Any] + ) -> DataLoader: + """Test data loader instantiation. + + Args: + dataset: a distributed dataset object. + data_loader_kwargs: additional arguments for the data loader. + + Returns: + the test data loader. 
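+
+        Any extra keyword arguments are forwarded to the DataLoader, e.g.
+        (editor's illustration):
+
+            loader = trainer.get_test_dataloader(test_dataset, num_workers=2)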
+ """ + return self._get_dataloader( + dataset=dataset, + data_loader_kwargs=dict( + pin_memory=True, + shuffle=False, + batch_size=self.config.train.batch_size, + ), + **additional_data_loader_kwargs, + ) + + def _train( + self, + train_dataset: DistributedDataset[TBatch], + valid_dataset: DistributedDataset[TBatch], + ) -> TModel: + """ + Train a model on a training dataset and occasionally run it on the + validation set. + """ + + if self._is_cuda: + torch.cuda.empty_cache() + + train_conf = self.config.train + + # instantiate train and validation data loaders, currently + # no additional arguments forwarded to instantiation. + train_loader = self.get_train_dataloader(train_dataset) + valid_loader = self.get_valid_dataloader(valid_dataset) + + # calculate the number of data elements this worker should train on + # based on the number of (global) target elements and number of workers + # (cpus/gpus) available. because we use a distributed sampler, the local + # test data loader only sees 1/world_size of the training data already + global_target_elements = train_conf.target_elements + local_target_elements = global_target_elements // self._world_size + + # by setting drop_last=True in the train loader, we make sure all batches have + # the same number of elements + if self.config.train.target_elements_strategy == "batch": + elements_per_batch = self.config.train.batch_size + else: + elements_per_batch = self.config.train.batch_size * self.config.params.input_seq_len + + # because global target elements is a lower bound - we always want to + # train on at least this number of elements - we may train on one more + # batch (due to batch sizes and sequence lengths). in order to reach the + # lower bound, the actual number of elements trained per worker may be + # slightly higher. 
+ local_batch_iters = math.ceil(local_target_elements / elements_per_batch) + expected_local_elements = elements_per_batch * local_batch_iters + expected_global_elements = expected_local_elements * self._world_size + delta = expected_global_elements - global_target_elements + + self._log.debug(f"Global target elements: {global_target_elements}") + self._log.debug( + f"Local target elements: {local_target_elements} ({self._world_size} workers)" + ) + self._log.debug(f"Elements per batch: {elements_per_batch} (bs: {train_conf.batch_size}) ") + self._log.debug( + f"Expected global target elements (w.r.t batch size): {expected_global_elements}" + ) + self._log.debug( + f"Expected local target elements (w.r.t batch size): {expected_local_elements}" + ) + + self._log.debug(f"Target element delta (global): {delta} elements") + self._log.info(f"Running {local_batch_iters} batch iterations") + + optimizer = self.configure_optimizer(self._model_dist.parameters()) + + local_gradient_steps = local_batch_iters // train_conf.gradient_accumulate_every + scheduler = self.configure_scheduler(optimizer, local_gradient_steps) + + epoch = 0 + epoch_batch_idx = 0 + if self.config.resume: + self._log.debug("Resuming training, offsetting start epoch and batch index") + epoch = self.config.resume.next_epoch_index + epoch_batch_idx = self.config.resume.next_batch_index + train_dataset.offset_to(epoch) + else: + self._log.info("Starting training from scratch") + self._log.debug(f"Starting from epoch {epoch}") + self._log.debug(f"Starting from batch {epoch_batch_idx}") + self._log_cuda_memory_snapshot(None) + + cumulative_batch_idx_start = len(train_loader) * epoch + epoch_batch_idx + global_log_train_idxs, global_run_valid_idxs = self._calc_logging_points( + local_batch_iters, + start_batch_idx=cumulative_batch_idx_start, + ) + + def before_new_epoch(epoch: int) -> None: + self._log.info(f"Initializing epoch {epoch}") + train_dataset.offset_to(epoch) + + gradient_scaler = torch.GradScaler(device=self._device_type) + + # total elements seen during trainings + elements_seen_total = 0 + curr_avg_grad: float = -1 + curr_avg_grad_clipped: float = -1 + for iteration in tqdm( + epoch_cycler( + train_loader, + before_new_epoch=before_new_epoch, + start_epoch=epoch, + start_batch=epoch_batch_idx, + max_iters=local_batch_iters, + ), + desc="Training", + mininterval=self.options.train_prog_min_interval_seconds, + disable=not self.options.display_progress, + ): + epoch, epoch_batch_idx, batch = iteration.epoch, iteration.batch, iteration.item + next_epoch, next_batch_idx = iteration.next_epoch, iteration.next_batch + + # keep track of the cumulative batch index across epochs for + # logging. this is used for the log points (when to log train loss + # and run validation). the value is also used to log the global + # batch index for convenient post-processing of the logs + cumulative_batch_idx = len(train_loader) * epoch + epoch_batch_idx + log_prefix = f"{[cumulative_batch_idx]}" + + self._model_dist.train() + + # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation + with torch.autocast( + device_type=self._device_type, + dtype=self.options.amp_dtype, + ): + start_fw_measure = time() + train_loss = self.model_forward( + # warning - we cast so that the arguments type hints for + # TModel.forward() are preserved, however, because the model + # has been wrapped with DistributedDataParallel, other + # methods might not be available. 
Use only model.forward() + cast(TModel, self._model_dist), + batch=batch, + device=self._device, + ) + fw_exec_time = time() - start_fw_measure + train_loss_as_flt = float(train_loss.item()) + if math.isnan(train_loss_as_flt): + self._log.error( + f"Invalid loss at batch {epoch_batch_idx}, epoch {epoch}. Train loss (raw): {train_loss}, batch: {batch.shape}" + ) + + # scale the gradient + train_loss = train_loss / train_conf.gradient_accumulate_every + elements_seen_total += elements_per_batch + + scaled_loss = gradient_scaler.scale(train_loss) + + start_bw_measure = time() + scaled_loss.backward() + bw_exec_time = time() - start_bw_measure + + # if enabled, report forward/backward pass execution times for the + # first track_first_fw_bw_exec_times iterations as well as cuda + # memory usage. after a few iterations, we can usually be sure + # memory will not further increase assuming there are no memory + # leaks. skip the first forward/backward pass, which takes much more + # time due to the construction of the computation graph, optimizer + # warmup, etc. + + if self.options.track_first_fw_bw_exec_times: + self.options.track_first_fw_bw_exec_times -= 1 + self._write_csv_timemem( + cum_batch=cumulative_batch_idx, + kind=ModelMode.TRAIN, + fw_time=fw_exec_time, + bw_time=bw_exec_time, + # batch size is constant during training + num_items=self.config.train.batch_size, + ) + + if cumulative_batch_idx in global_log_train_idxs: + self._log.info(f"{log_prefix} Training loss: {train_loss_as_flt}") + self._write_csv_loss( + ModelMode.TRAIN, + loss=train_loss_as_flt, + epoch=epoch, + batch=epoch_batch_idx, + cum_batch=cumulative_batch_idx, + elements_seen=elements_seen_total, + lr=scheduler.get_last_lr()[0], + avg_grad=curr_avg_grad, + avg_grad_clipped=curr_avg_grad_clipped, + ) + + # accumulate the gradient with clipping: + # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping + if (epoch_batch_idx + 1) % train_conf.gradient_accumulate_every == 0: + # restore the scaled gradient for clipping + gradient_scaler.unscale_(optimizer) + + curr_avg_grad = self.avg_gradient_value() + + if (max_clip := self.config.train.gradient_clipping) is not None: + torch.nn.utils.clip_grad_norm_( + self._model_dist.parameters(), + max_clip, + ) + + curr_avg_grad_clipped = self.avg_gradient_value() + + gradient_scaler.step(optimizer) + scale = gradient_scaler.get_scale() + gradient_scaler.update() + + # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/7 + skip_lr_sched = scale > gradient_scaler.get_scale() + if scheduler and not skip_lr_sched: + scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # before evaluation, we do not perform a gradient update. hence, the + # "elements_seen_total" we log below might be slightly off. + # specifically, this number might be larger than the true number. + # this is because the validation might be performed while gradients + # are still being accumulated, and thus have the model has not + # learned from the elements yet. 
on a large scale, this hardly + # matters + if not self.options.skip_validation and cumulative_batch_idx in global_run_valid_idxs: + valid_loss = self._evaluate( + self._model_dist, + valid_loader, + items_seen_so_far=elements_seen_total, + cumulative_batch_idx=cumulative_batch_idx, + ) + self._log.info(f"{log_prefix} Validation loss: {valid_loss}") + self._write_csv_loss( + ModelMode.VALID, + loss=valid_loss, + epoch=epoch, + batch=epoch_batch_idx, + cum_batch=cumulative_batch_idx, + elements_seen=elements_seen_total, + lr=-1, + avg_grad=-1, + avg_grad_clipped=-1, + ) + + # after validating, save the state, maybe it's really good! + # before i/o, use a barrier to make sure training states are in + # sync (as seen in https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html ) + dist.barrier() + original_model = self._unpack_distributed_model(self._model_dist) + self._top_n_models.add( + ( + valid_loss, + original_model.state_dict(), + ) + ) + self._save_training_state(next_batch_idx, next_epoch) + + else: + # we have seen exactly local_batch_iters batches + elements_match = elements_seen_total == expected_local_elements + if not elements_match: + self._log.fatal( + f"Mismatch between expected and actual elements seen: {expected_local_elements}, {elements_seen_total}" + ) + self._log.info("Finished training") + self._log.info(f"Stats (local): Elements seen: {elements_seen_total}") + + best_model = self._unpack_distributed_model(self._model_dist) + + if self._is_main_worker: + # if, on the main worker, populate the model with the best state + # non-main workers will simply return the latest model, which won't + # be used anyway because testing happens only on the main worker + ((least_loss, best_state),) = self._top_n_models.get_top(1) + best_model.load_state_dict(best_state) + self._log.info(f"Returning model with least loss ({least_loss})") + + self._log_cuda_memory_snapshot(None) + + return best_model + + def test( + self, + test_dataset: DistributedDataset[TBatch], + model: torch.nn.Module, + ) -> None: + # test only on main worker + if not self._is_main_worker: + return + + # instantiate test data loader, currently + # no additional arguments forwarded to instantiation. + test_loader = self.get_test_dataloader(test_dataset) + self._log.info("Started testing") + model.eval() + test_loss = self._evaluate(model, test_loader, -1, -1) + self._log.info(f"Test loss: {test_loss}") + self._write_csv_loss( + ModelMode.TEST, + loss=test_loss, + elements_seen=-1, + epoch=-1, + batch=-1, + cum_batch=-1, + lr=-1, + avg_grad=-1, + avg_grad_clipped=-1, + ) + self._log.info("Finished testing") + return None diff --git a/src/mblm/trainer/iter.py b/src/mblm/trainer/iter.py new file mode 100644 index 0000000..273404f --- /dev/null +++ b/src/mblm/trainer/iter.py @@ -0,0 +1,117 @@ +__copyright__ = """MIT License + +Copyright (c) 2024 - IBM Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.""" + +from dataclasses import dataclass +from typing import Any, Callable, Generator, Generic, Iterator, Protocol, TypeVar + +_T = TypeVar("_T", covariant=True) + + +class _SizedIterable(Protocol[_T]): + """ + An internal protocol type used by the epoch_cycler, typically to allow Torch + DataLoaders as iterables for the cycler. Ideally, we'd have the input + sequence be `Sequence` to facilitate the logic in the epoch cycler. However, + Dataloaders are not `Sequence` because they do not implement `__getitem__` + and thus cannot be used to efficiently index. But since DataLoaders are both + `Sized` and `Iterator`, we can make use that. See + https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes + """ + + def __len__(self) -> int: ... + + def __iter__(self) -> Iterator[_T]: ... + + +# TODO: Python 3.12, NamedTuple +@dataclass +class EpochCyclerYield(Generic[_T]): + epoch: int + batch: int + item: _T + next_epoch: int + next_batch: int + + +def epoch_cycler( + seq: _SizedIterable[_T], + *, + before_new_epoch: Callable[[int], Any] | None = None, + start_epoch: int = 0, + start_batch: int = 0, + max_iters: int | None = None, +) -> Generator[EpochCyclerYield[_T], None, None]: + """ + Infintely iterates over a sequence, yielding batches and their indices as + well as the current epoch (one complete iteration over the sequence). + + Args: + seq (_SizedIterable[T]): An object that implements both __size__ and + __len__ holding items T + before_new_epoch (Callable[[int], Any] | None = None): Optional + callback function that is called right **before** the `n`-th starts + with `n`. Can be used to shuffle a dataset or perform arbitry side effects + start_epoch (int = 0): Start/resume from this epoch + start_batch (int = 0): Skip until and yield **starting from** this index + max_iters (int | None = None): Return when this total amount of batches has + been yielded across epochs + + Returns: + Generator (EpochCyclerYield): Yields the epoch, the batch index and the item at + that index. Batch indices are reset at the start of each epoch (that is, + they range from 0 to len(seq)). In order to restore the iterator's state, + the next epoch and next batch index are also returned for convenience. + """ + epoch = start_epoch + global_batch_counter = 0 + if start_batch >= len(seq): + raise IndexError("start_batch is larger than the length of the sequence") + + if before_new_epoch: + before_new_epoch(epoch) + + it = iter(seq) + # advance the first iterator to the start batch index + for _ in range(0, start_batch): + next(it) + + while True: + # do not factor out the call to len - if the sequence changes between + # epochs, we want to account for that. 
len is O(1), doesn't matter + for i in range(start_batch, len(seq)): + if global_batch_counter == max_iters: + return + item = next(it) + if i == len(seq) - 1: + next_epoch = epoch + 1 + next_batch = 0 + else: + next_epoch = epoch + next_batch = i + 1 + yield EpochCyclerYield[_T](epoch, i, item, next_epoch, next_batch) + global_batch_counter += 1 + start_batch = 0 + epoch += 1 + + if before_new_epoch: + before_new_epoch(epoch) + it = iter(seq) diff --git a/tests/e2e/trainer/__init__.py b/tests/e2e/trainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/e2e/trainer/run_trainer.py b/tests/e2e/trainer/run_trainer.py new file mode 100644 index 0000000..3af58a3 --- /dev/null +++ b/tests/e2e/trainer/run_trainer.py @@ -0,0 +1,184 @@ +""" + +From the project root, run: +make test_e2e +""" + +import math +import os +from pathlib import Path + +import torch +from torch.optim import Adam, Optimizer # type: ignore +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR + +from mblm.data.datasets import DistributedDataset +from mblm.trainer.config import CoreIoConfig, CoreModelParams, CoreTrainConfig, GenericEntryConfig +from mblm.trainer.core import CoreTrainer, CoreTrainerOptions +from mblm.utils.distributed import process_group +from mblm.utils.io import load_yml +from mblm.utils.logging import create_logger +from mblm.utils.misc import count_params +from mblm.utils.seed import seed_everything + +# TODO: Python 3.12, assert_type + + +STORE_N_MODELS = 2 +WRITE_TRAIN_LOSS_TIMES = 20 +WRITE_VALID_LOSS_TIMES = 10 +NUM_TRAIN_ELEMENTS = 10_000 + + +TBatch = tuple[torch.Tensor, torch.Tensor] + + +class ModelParams(CoreModelParams): + hidden_size: int + output_size: int + + +class TrainEntryConfig(GenericEntryConfig[ModelParams, CoreTrainConfig, CoreIoConfig]): + pass + + +class SimpleNN(torch.nn.Module): + def __init__(self, input_size: int, output_size: int, hidden_size: int): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, output_size) + + def forward(self, x: torch.Tensor): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + +class SineDataset(DistributedDataset[TBatch]): + def __init__(self, num_samples: int, worker_id: int, num_workers: int): + self.x = torch.linspace(0, 2 * torch.pi, num_samples).unsqueeze(1) + # complicated dummy function to learn + self.y = ( + torch.sin(2 * torch.pi * self.x) + + torch.cos(4 * torch.pi * self.x) + + 0.5 * self.x**2 + - 0.3 * self.x + ) + super().__init__( + data_size=self.x.numel(), + is_sequential=True, + seq_len=1, + worker_id=worker_id, + num_workers=num_workers, + ) + + def get_sample(self, from_idx): + to_idx = from_idx + self.seq_len + return self.x[from_idx:to_idx], self.y[from_idx:to_idx] + + +class TestTrainer(CoreTrainer[SimpleNN, TBatch, ModelParams, CoreTrainConfig, CoreIoConfig]): + # note regarding type hints: in the real world, the parameters are inferred + # correctly via the generics. 
however, for the purpose of asserting on the + # output types, we need to annotate them explicitly + + def init_model(self): + # assert_type(output_dir, Path) + # assert_type(is_main_worker, bool) + # assert_type(model_params, ModelParams) + seed_everything(8) + return SimpleNN( + input_size=self.config.params.input_seq_len, + output_size=self.config.params.output_size, + hidden_size=self.config.params.hidden_size, + ) + + def model_forward(self, model, batch, device): + # assert_type(model, SimpleNN) + # assert_type(batch, BatchType) + # assert_type(device, str) + x, y = batch + output = model.forward(x.to(device)) + loss_function = torch.nn.MSELoss() + loss: torch.Tensor = loss_function(output, y) + return loss + + def configure_optimizer(self, parameters) -> Optimizer: + seed_everything(8) + return Adam( + parameters, + lr=self.config.train.learning_rate, + betas=(0.9, 0.95), + ) + + def configure_scheduler(self, optimizer: Optimizer, local_gradient_steps: int) -> LRScheduler: + seed_everything(8) + warmup_steps = math.floor(local_gradient_steps * self.config.train.warmup_steps_perc) + linear = LinearLR( + optimizer, + total_iters=warmup_steps, + start_factor=0.1, + end_factor=1, + ) + cosine_iters = local_gradient_steps - warmup_steps + cosine = CosineAnnealingLR(optimizer, T_max=cosine_iters) + return SequentialLR( + optimizer, + [linear, cosine], + milestones=[warmup_steps], + ) + + def configure_count_parameters(self, model: SimpleNN): + return count_params(model, ("fc_layers", [model.fc1, model.fc2])) + + def configure_run_id(self) -> str: + return os.environ["TEST_ID"] + + +def run(config: TrainEntryConfig) -> None: + log = create_logger(__name__) + log.info("Initiating distributed training") + + try: + with process_group(backend="gloo") as run_vars: + train_dataset = SineDataset( + NUM_TRAIN_ELEMENTS, + worker_id=run_vars.local_rank, + num_workers=run_vars.world_size, + ) + valid_dataset = SineDataset( + 100, + worker_id=0, + num_workers=1, + ) + test_dataset = SineDataset( + 100, + worker_id=0, + num_workers=1, + ) + trainer = TestTrainer( + config, + run_vars=run_vars, + options=CoreTrainerOptions( + # disable tqdm for tests + display_progress=False, + ), + ) + best_model = trainer.train(train_dataset, valid_dataset) + if best_model: + trainer.test(test_dataset, best_model) + except Exception as error: + log.fatal(error) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("-c", type=Path, required=True, dest="config_path") + args = parser.parse_args() + config_path: Path = args.config_path + cfg = load_yml(config_path, parse_to=TrainEntryConfig) + run(cfg) diff --git a/tests/e2e/trainer/sample-config-0.5-epoch.yml b/tests/e2e/trainer/sample-config-0.5-epoch.yml new file mode 100644 index 0000000..539ac6c --- /dev/null +++ b/tests/e2e/trainer/sample-config-0.5-epoch.yml @@ -0,0 +1,19 @@ +io: + name_model: my-model + output_dir: tests/e2e/trainer/outputs # static + num_models_to_save: 2 + validate_amount: 10 + log_train_loss_amount: 20 +params: + input_seq_len: 1 + hidden_size: 20 + output_size: 1 +train: + # There are 10_000 elements in the training set + target_elements: 5_870 + batch_size: 8 + shuffle_train: true + learning_rate: 0.001 + gradient_clipping: 0.5 + gradient_accumulate_every: 10 + target_elements_strategy: batch diff --git a/tests/e2e/trainer/sample-config-1-epoch.yml b/tests/e2e/trainer/sample-config-1-epoch.yml new file mode 100644 index 0000000..be1bba1 --- /dev/null +++ 
b/tests/e2e/trainer/sample-config-1-epoch.yml @@ -0,0 +1,18 @@ +io: + name_model: my-model + output_dir: tests/e2e/trainer/outputs # static + num_models_to_save: 2 + validate_amount: 10 + log_train_loss_amount: 20 +params: + input_seq_len: 1 + hidden_size: 20 + output_size: 1 +train: + # There are 10_000 elements in the training set + target_elements: 10_000 + batch_size: 8 + learning_rate: 0.001 + gradient_clipping: 0.5 + gradient_accumulate_every: 10 + target_elements_strategy: batch diff --git a/tests/e2e/trainer/sample-config-1.5-epoch.yml b/tests/e2e/trainer/sample-config-1.5-epoch.yml new file mode 100644 index 0000000..877936a --- /dev/null +++ b/tests/e2e/trainer/sample-config-1.5-epoch.yml @@ -0,0 +1,20 @@ +io: + name_model: my-model + output_dir: tests/e2e/trainer/outputs # static + num_models_to_save: 2 + validate_amount: 10 + log_train_loss_amount: 20 + description: >- + This is config 2 to test gradient accumulation +params: + input_seq_len: 1 + hidden_size: 20 + output_size: 1 +train: + # There are 10_000 elements in the training set + target_elements: 16_611 + batch_size: 8 + learning_rate: 0.001 + gradient_clipping: 0.5 + gradient_accumulate_every: 10 + target_elements_strategy: batch diff --git a/tests/e2e/trainer/sample-config-grad-acc-1.yml b/tests/e2e/trainer/sample-config-grad-acc-1.yml new file mode 100644 index 0000000..cb2470b --- /dev/null +++ b/tests/e2e/trainer/sample-config-grad-acc-1.yml @@ -0,0 +1,21 @@ +io: + name_model: my-model + output_dir: tests/e2e/trainer/outputs # static + num_models_to_save: 2 + validate_amount: 10 + log_train_loss_amount: 20 + description: >- + This is config 1 to test gradient accumulation +params: + input_seq_len: 1 + hidden_size: 20 + output_size: 1 +train: + # There are 10_000 elements in the training set + target_elements: 15_000 + batch_size: 5 + learning_rate: 0.001 + gradient_clipping: 1 + shuffle_train: true + gradient_accumulate_every: 10 + target_elements_strategy: batch diff --git a/tests/e2e/trainer/sample-config-grad-acc-2.yml b/tests/e2e/trainer/sample-config-grad-acc-2.yml new file mode 100644 index 0000000..ce6296f --- /dev/null +++ b/tests/e2e/trainer/sample-config-grad-acc-2.yml @@ -0,0 +1,21 @@ +io: + name_model: my-model + output_dir: tests/e2e/trainer/outputs # static + num_models_to_save: 2 + validate_amount: 10 + log_train_loss_amount: 20 + description: >- + This is config 2 to test gradient accumulation +params: + input_seq_len: 1 + hidden_size: 20 + output_size: 1 +train: + # There are 10_000 elements in the training set + target_elements: 15_000 + batch_size: 10 + learning_rate: 0.001 + gradient_clipping: 1 + shuffle_train: true + gradient_accumulate_every: 5 + target_elements_strategy: batch diff --git a/tests/e2e/trainer/validate.py b/tests/e2e/trainer/validate.py new file mode 100644 index 0000000..9e64896 --- /dev/null +++ b/tests/e2e/trainer/validate.py @@ -0,0 +1,142 @@ +from argparse import ArgumentParser, BooleanOptionalAction +from datetime import datetime +from pathlib import Path + +import polars as pl + +from mblm.trainer.config import CoreIoConfig, CoreModelParams, CoreTrainConfig, GenericOutputConfig +from mblm.utils.io import load_yml + + +class TrainOutputConfig(GenericOutputConfig[CoreModelParams, CoreTrainConfig, CoreIoConfig]): + pass + + +def ensure_no_error_logs(log_file: Path): + with Path.open(log_file) as log: + for line in log.readlines(): + _, _, level, msg = line.split(" - ") + if level == "CRITICAL" or level == "ERROR": + raise AssertionError(msg) + + +def 
assert_on_model_run_output(output_dir: Path) -> None: + checkpoints = list(output_dir.rglob("*.pth")) + csv_loss_file = output_dir / "loss.csv" + yml_config_file = output_dir / "config.yml" + log_file = output_dir / "train.log" + + assert csv_loss_file.exists(), "Expected a CSV loss file" + assert yml_config_file.exists(), "Expected a YML config file" + assert log_file.exists(), "Expected a log file" + + ensure_no_error_logs(log_file) + + run_config = load_yml(yml_config_file, parse_to=TrainOutputConfig) + + assert ( + len(checkpoints) == run_config.io.num_models_to_save + ), f"Expected {run_config.io.num_models_to_save} model checkpoints" + + # assert that we get the specified number of training loss logs + csv_log = pl.read_csv(csv_loss_file) + + num_train_loss_entries = csv_log.filter(pl.col("kind") == "train").select(pl.len()).item() + assert ( + num_train_loss_entries == run_config.io.log_train_loss_amount + ), f"Expected {run_config.io.log_train_loss_amount} train loss entries, received {num_train_loss_entries}" + + # assert that we get the specified number of validation loss logs + num_valid_loss_entries = csv_log.filter(pl.col("kind") == "valid").select(pl.len()).item() + assert ( + num_valid_loss_entries == run_config.io.validate_amount + ), f"Expected {run_config.io.validate_amount} validation loss entries, received {num_valid_loss_entries}" + + # assert that the test loss is logged + num_test_loss_entries = csv_log.filter(pl.col("kind") == "test").select(pl.len()).item() + assert num_test_loss_entries == 1 + + # assert we log valid dates in the summary + try: + datetime.fromisoformat(run_config.summary.training_start) + datetime.fromisoformat(run_config.summary.training_end) + except ValueError: + raise AssertionError("Failed to parse training start/end dates") + + +def assert_on_chained_logs(csv_loss_files: list[Path], assert_equal_epochs: bool): + for file_idx in range(0, len(csv_loss_files) - 1): + file_1 = csv_loss_files[file_idx] + file_2 = csv_loss_files[file_idx + 1] + + prev = pl.read_csv(file_1).filter(pl.col("kind") == "train") + curr = pl.read_csv(file_2).filter(pl.col("kind") == "train") + + # every single run was exactly one epoch + if assert_equal_epochs: + # this test also ensures the dfs are the same length + assert prev.select(pl.col("epoch") + 1).equals( + curr.select(pl.col("epoch")), + ), "Chained runs did not advance the epoch by exactly one" + + assert prev.select(pl.col("cum_batch").max() + 1).equals( + curr.select(pl.col("cum_batch").min()) + ), "Training did not exactly resume from previous training" + + assert prev.select(pl.col("elements_seen").sum()).equals( + curr.select(pl.col("elements_seen").sum()) + ), "Same configuration saw different number of train items" + + for df in [prev, curr]: + # the chained test is designed to run for around 3 epochs, adjust if necessary + assert df.select(pl.col("epoch").max()).item() < 4, "Expected fewer than 4 epochs" + + +def assert_on_grad_acc(csv_loss_files: list[Path]): + for file_idx in range(0, len(csv_loss_files) - 1): + file_1 = csv_loss_files[file_idx] + file_2 = csv_loss_files[file_idx + 1] + + prev = pl.read_csv(file_1).sort("timestamp") + curr = pl.read_csv(file_2).sort("timestamp") + + # make sure same loss on test + assert ( + curr.select(pl.col("loss").last().round(3)).item() + == prev.select(pl.col("loss").last().round(3)).item() + ) + + +def run_test(): + parser = ArgumentParser() + parser.add_argument("--check-output", type=Path, dest="check_output") + parser.add_argument("--check-chained-csv", action="append", type=Path, dest="check_chained_csv") + 
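# Note: --check-output, --check-chained-csv and --check-grad-acc-csv (added
# below) select mutually exclusive validation modes; the two CSV flags are
# repeatable and each collect a list of loss files. An illustrative invocation
# with hypothetical paths might look like:
#   uv run tests/e2e/trainer/validate.py \
#     --check-grad-acc-csv outputs/run_a/loss.csv \
#     --check-grad-acc-csv outputs/run_b/loss.csv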
parser.add_argument( + "--check-grad-acc-csv", action="append", type=Path, dest="check_grad_acc_csv" + ) + parser.add_argument("--assert-equal-epochs", action=BooleanOptionalAction, default=False) + args = parser.parse_args() + + check_output: Path | None = args.check_output + check_chained_csv: list[Path] | None = args.check_chained_csv + check_grad_acc_csv: list[Path] | None = args.check_grad_acc_csv + assert_equal_epochs: bool = args.assert_equal_epochs + + if check_output and check_chained_csv and check_grad_acc_csv: + raise AssertionError( + "Specify either --check-output, --check-chained-csv or --check-grad-acc-csv" + ) + if check_output: + print(f"Validating {check_output}") + return assert_on_model_run_output(check_output) + if check_chained_csv: + print(f"Validating {len(check_chained_csv)} chained runs") + return assert_on_chained_logs(check_chained_csv, assert_equal_epochs) + if check_grad_acc_csv: + print("Asserting on gradient accumulation") + return assert_on_grad_acc(check_grad_acc_csv) + raise AssertionError("No tests ran") + + +if __name__ == "__main__": + run_test() diff --git a/tests/fixtures/clevr/images/val/CLEVR_val_000000.png b/tests/fixtures/clevr/images/val/CLEVR_val_000000.png new file mode 100644 index 0000000..709fc86 Binary files /dev/null and b/tests/fixtures/clevr/images/val/CLEVR_val_000000.png differ diff --git a/tests/fixtures/clevr/questions/CLEVR_val_questions.json b/tests/fixtures/clevr/questions/CLEVR_val_questions.json new file mode 100644 index 0000000..3702c4f --- /dev/null +++ b/tests/fixtures/clevr/questions/CLEVR_val_questions.json @@ -0,0 +1,31 @@ +{ + "info": { + "split": "val", + "license": "Creative Commons Attribution (CC BY 4.0)", + "version": "1.0", + "date": "2/14/2017" + }, + "questions": [ + { + "image_index": 0, + "program": [ + { "inputs": [], "function": "scene", "value_inputs": [] }, + { "inputs": [0], "function": "filter_size", "value_inputs": ["large"] }, + { + "inputs": [1], + "function": "filter_material", + "value_inputs": ["metal"] + }, + { "inputs": [2], "function": "unique", "value_inputs": [] }, + { "inputs": [3], "function": "same_shape", "value_inputs": [] }, + { "inputs": [4], "function": "exist", "value_inputs": [] } + ], + "question_index": 0, + "image_filename": "CLEVR_val_000000.png", + "question_family_index": 39, + "split": "val", + "answer": "cylinder", + "question": "The longest question in clevr (across modes) has a byte/ASCII length of 205. This here is a dummy question of the same length to ensure we test clevr with the max length.
Filler words: abcdefghijklmnopqrst" + } + ] +} diff --git a/tests/integration/model_config/test_config_to_model.py b/tests/integration/model_config/test_config_to_model.py new file mode 100644 index 0000000..1978465 --- /dev/null +++ b/tests/integration/model_config/test_config_to_model.py @@ -0,0 +1,60 @@ +from pathlib import Path +from typing import Iterable + +import pytest + +from mblm import MBLM, MBLMModelConfig +from mblm.data.dataset.clevr import ClevrOptionalArgs +from mblm.model.mamba import MambaBlockConfig +from mblm.model.transformer import TransformerBlockConfig +from mblm.scripts.train_mblm import TrainEntryConfig +from mblm.utils.io import load_yml + +# TODO +CONFIG_FILES_DIR = "config" +CONFIG_FILES = [(config,) for config in Path(CONFIG_FILES_DIR).glob("*.yaml")] + + +class TestConfigToModel: + def ensure_config_is_valid(self, config: Path): + try: + return load_yml(config, parse_to=TrainEntryConfig) + except Exception: + pytest.fail(f"Invalid config {config}") + + def ensure_dataset_args_are_valid(self, config: TrainEntryConfig) -> None: + if config.io.dataset_id == "clevr": + try: + ClevrOptionalArgs.model_validate(config.io.dataset_args) + except Exception: + pytest.fail(f"Invalid clevr dataset kwargs ({config.io.name_model})") + return None + + def ensure_model_is_created(self, config: TrainEntryConfig) -> None: + for b in config.params.blocks(): + assert isinstance(b, (TransformerBlockConfig, MambaBlockConfig)) + if isinstance(b, TransformerBlockConfig): + assert b.block_type == "transformer" + else: + # mamba1, can be mamba2 (only if tested on Linux with mamba_ssm installed) + assert b.block_type.startswith("mamba") + + _ = MBLM( + MBLMModelConfig( + num_tokens=config.params.num_tokens, + hidden_dims=config.params.hidden_dims, + seq_lens=config.params.seq_lens, + num_layers=config.params.num_layers, + pad_token_id=config.params.pad_token_id, + train_checkpoint_chunks=config.params.train_checkpoint_chunks, + block=config.params.block, + ) + ) + return None + + @pytest.mark.parametrize("config_files", CONFIG_FILES) + def test_config_to_mmb_transformer(self, config_files: Iterable[Path]): + for config_file in config_files: + config = self.ensure_config_is_valid(config_file) + self.ensure_dataset_args_are_valid(config) + self.ensure_model_is_created(config) diff --git a/tests/unit/data/test_bytes.py b/tests/unit/data/test_bytes.py new file mode 100644 index 0000000..bc0bce5 --- /dev/null +++ b/tests/unit/data/test_bytes.py @@ -0,0 +1,36 @@ +from io import BytesIO + +import pytest +import torch + +from mblm.data.utils import Bytes, FileStream + + +class TestByteUtils: + @pytest.mark.parametrize( + "string,expected_bytes", + [ + ["hello", [104, 101, 108, 108, 111]], + ["über", [195, 188, 98, 101, 114]], + ["日本", [230, 151, 165, 230, 156, 172]], + ["🥰", [240, 159, 165, 176]], + ], + ) + def test_string_utils(self, string: str, expected_bytes: list[int]): + b = Bytes.str_to_bytes(string) + assert b == bytes(expected_bytes) == bytearray(expected_bytes) + assert Bytes.bytes_to_str(b) == string + + t = Bytes.str_to_tensor(string) + assert t.equal(torch.tensor(expected_bytes, dtype=torch.uint8)) + assert Bytes.tensor_to_str(t) == string + + +class TestFileStream: + def test_file_stream(self): + data = bytes([1, 2, 3]) + inp_stream = BytesIO() + inp_stream.write(data) + stream = FileStream(inp_stream) + assert stream.to_buffer() == data + assert stream.to_numpy().tolist() == stream.to_tensor().tolist() diff --git a/tests/unit/data/test_clevr.py b/tests/unit/data/test_clevr.py 
new file mode 100644 index 0000000..006d0f7 --- /dev/null +++ b/tests/unit/data/test_clevr.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import cast + +import pydantic +import pytest +import torch + +from mblm.data.dataset.clevr import Clevr, ClevrOptionalArgs +from mblm.data.types import ModelMode +from mblm.data.utils.bytes import Bytes +from mblm.scripts.train_mblm import TrainEntryConfig +from mblm.utils.io import load_yml + +CLEVR_FIXTURE_PATH = Path("tests/fixtures/clevr") + + +def import_experiment_config(exp_config_path: str | Path) -> TrainEntryConfig: + return load_yml(exp_config_path, parse_to=TrainEntryConfig) + + +def import_clevr_from_experiment(exp_config: str | TrainEntryConfig) -> Clevr: + """ + General helper to load the Clevr dataset from a yml experiment config + """ + exp_conf = ( + exp_config + if isinstance(exp_config, TrainEntryConfig) + else import_experiment_config(exp_config) + ) + # adjust the dataset dir to point to the fixture + exp_conf.io.dataset_dir = str(CLEVR_FIXTURE_PATH) + # we can be sure that this is clevr + return cast(Clevr, exp_conf.import_dataset(ModelMode.VALID, worker_id=0, num_workers=1)) + + +def import_clevr_from_fixture(pad_token_id: int, optional_args: ClevrOptionalArgs): + """ + General helper to directly import the Clevr dataset from the fixture + """ + return Clevr( + CLEVR_FIXTURE_PATH, + mode=ModelMode.VALID, + pad_token_id=pad_token_id, + optional_args=optional_args, + seq_len=500_000, + num_workers=1, + worker_id=0, + ) + + +class TestClevrDataset: + PAD_TOKEN_ID = 1001 + EOM_TOKEN_ID = 1002 + SOM_IMG_TOKEN_ID = 1003 + SOM_TXT_TOKEN_ID = 1004 + + def test_clevr_dummy_data(self): + clevr = import_clevr_from_fixture( + pad_token_id=self.PAD_TOKEN_ID, + optional_args=ClevrOptionalArgs(qiqa_loss_mask=(1.0, 1.0, 1.0, 1.0), target_mode="a"), + ) + # our dummy dataset has just 1 sample + assert len(clevr) == 1 + s = clevr.get_sample_raw(0) + question, answer = s["question"], s["answer"] + assert len(Bytes.str_to_tensor(question)) == clevr.MAX_QUESTION_LEN_BYTES + assert len(Bytes.str_to_tensor(answer)) == clevr.MAX_ANSWER_LEN_BYTES + + def test_sample_modality_indices(self): + clevr = import_clevr_from_fixture( + pad_token_id=self.PAD_TOKEN_ID, + optional_args=ClevrOptionalArgs( + eom_token_id=self.EOM_TOKEN_ID, + som_text_token_id=self.SOM_TXT_TOKEN_ID, + som_image_token_id=self.SOM_IMG_TOKEN_ID, + qiqa_loss_mask=(1.0, 1.0, 1.0, 1.0), + target_mode="a", + ), + ) + sample, _, (question, image, answer, padding) = clevr.get_sample_with_parts(0) + assert sample.dtype == torch.long + + qiqa_reconstructed = torch.concat([question, image, question, answer]) + + assert len(qiqa_reconstructed) + padding == len(sample) + assert question[0].item() == self.SOM_TXT_TOKEN_ID + assert answer[0].item() == self.SOM_TXT_TOKEN_ID + assert image[0].item() == self.SOM_IMG_TOKEN_ID + + for item in (question, image, answer): + assert item[-1].item() == self.EOM_TOKEN_ID + + @pytest.mark.parametrize("quality", [-1, 96]) + def test_clevr_args_validation(self, quality: int): + with pytest.raises(pydantic.ValidationError): + _ = import_clevr_from_fixture( + pad_token_id=self.PAD_TOKEN_ID, + optional_args=ClevrOptionalArgs( + qiqa_loss_mask=(1.0, 1.0, 1.0, 1.0), + target_mode="a", + enable_jpeg_stream_with_quality=quality, + ), + ) diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py new file mode 100644 index 0000000..1c19a88 --- /dev/null +++ b/tests/unit/data/test_datasets.py @@ -0,0 +1,222 @@ +from collections import
Counter + +import pytest +import torch +from typing_extensions import Unpack + +from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig + + +class MySequentialDataset(DistributedDataset[list[int]]): + def __init__( + self, + data: list[int], + seq_len: int, + *, + worker_id: int, + num_workers: int, + ): + super().__init__( + data_size=len(data), + seq_len=seq_len, + is_sequential=True, + worker_id=worker_id, + num_workers=num_workers, + ) + self._data = data + + def get_sample(self, from_idx): + return self._data[from_idx : from_idx + self.seq_len] + + +class MyDataset(DistributedDataset[int]): + def __init__( + self, + data: list[int], + seq_len: int, + *, + worker_id: int, + num_workers: int, + ): + super().__init__( + data_size=len(data), + seq_len=seq_len, + is_sequential=False, + worker_id=worker_id, + num_workers=num_workers, + ) + self._data = data + + def get_sample(self, from_idx): + return self._data[from_idx] + + +class TestDistributedDataset: + def test_offset_single_worker(self): + data = list(range(14)) + sequence_len = 3 + dataset = MySequentialDataset(data, sequence_len, worker_id=0, num_workers=1) + assert len(dataset) == 4 # (14 // 1) // 3 + + assert dataset[0] == [0, 1, 2] + assert dataset[1] == [3, 4, 5] + assert dataset[2] == [6, 7, 8] + assert dataset[3] == [9, 10, 11] + + dataset.offset_one() + assert len(dataset) == 4 + assert dataset[0] == [1, 2, 3] + assert dataset[1] == [4, 5, 6] + assert dataset[2] == [7, 8, 9] + assert dataset[3] == [10, 11, 12] + + dataset.offset_to(2) + assert len(dataset) == 4 + assert dataset[0] == [2, 3, 4] + assert dataset[1] == [5, 6, 7] + assert dataset[2] == [8, 9, 10] + assert dataset[3] == [11, 12, 13] + + # full cycle + dataset.offset_one() + assert len(dataset) == 4 + assert dataset[0] == [0, 1, 2] + assert dataset[1] == [3, 4, 5] + assert dataset[2] == [6, 7, 8] + assert dataset[3] == [9, 10, 11] + + def test_offset_single_worker_long_seq(self): + # for long sequences, we expect that length decreases over time + data = list(range(13)) + sequence_len = 6 + dataset = MySequentialDataset(data, sequence_len, worker_id=0, num_workers=1) + assert len(dataset) == 2 + assert dataset[0] == [0, 1, 2, 3, 4, 5] + assert dataset[1] == [6, 7, 8, 9, 10, 11] + + dataset.offset_one() + assert len(dataset) == 2 + assert dataset[0] == [1, 2, 3, 4, 5, 6] + assert dataset[1] == [7, 8, 9, 10, 11, 12] + + dataset.offset_to(4) + assert len(dataset) == 1 # only one full sequence can be retrieved now + assert dataset[0] == [4, 5, 6, 7, 8, 9] + + dataset.offset_one() + dataset.offset_one() + assert len(dataset) == 2 + assert dataset[0] == [0, 1, 2, 3, 4, 5] + assert dataset[1] == [6, 7, 8, 9, 10, 11] + + @pytest.mark.parametrize("seq_len,range_end", [(4, 33), (8, 55), (7, 42), (5, 21), (5, 20)]) + def test_offset_two_workers(self, seq_len: int, range_end: int): + num_workers = 2 + data = list(range(0, range_end)) + d1 = MySequentialDataset(data, seq_len, worker_id=0, num_workers=num_workers) + d2 = MySequentialDataset(data, seq_len, worker_id=1, num_workers=num_workers) + + all_items: Counter[int] = Counter() + # test modulo op, cycle offset twice + for _ in range(seq_len * 2): + d1.offset_one() + d2.offset_one() + assert len(d1) == len(d2) + + for seq_idx in range(len(d1) - 1): + assert d1[seq_idx][-1] < d1[seq_idx + 1][0] + assert len(d1[seq_idx]) == seq_len + # add all sequence start elements to the counter + all_items.update([d1[seq_idx][0]]) + all_items.update([d2[seq_idx][0]]) + + # make sure every element has been the first item 
twice because we've + # cycled the offset twice + assert all([count == 2 for count in all_items.values()]) + + def test_non_sequential_ds(self): + dataset = MyDataset( + data=list(range(21)), + seq_len=-1, # does not matter + worker_id=0, + num_workers=1, + ) + assert len(dataset) == 21 + assert dataset[0] == 0 + assert dataset[1] == 1 + assert dataset[9] == 9 + + dataset.offset_one() # should have no effect but warn + assert dataset[0] == 0 + + dataset = MyDataset( + data=list(range(21)), + seq_len=-1, # does not matter + worker_id=1, + num_workers=2, + ) + assert len(dataset) == 10 + assert dataset[0] == 10 + assert dataset[1] == 11 + assert dataset[9] == 19 + + def test_internal_validation_sequential(self): + # data from 0 to 9, 10 elements + sequence_len = 10 + data = torch.arange(start=0, end=sequence_len) + + class MyDataset(DistributedDataset): + def __init__( + self, + data: torch.Tensor, + is_sequential: bool, + **config: Unpack[DistributedDatasetConfig], + ): + super().__init__( + data.numel(), + is_sequential=is_sequential, + **config, + ) + + def get_sample(self, from_idx): ... + + with pytest.raises(AssertionError) as msg: + MyDataset( + data, + seq_len=sequence_len, + is_sequential=True, + worker_id=1, + num_workers=1, + ) + assert msg.value == "worker_id (2) must be smaller than num_workers (2)" + + with pytest.raises(AssertionError) as msg: + MyDataset( + data, + seq_len=1, + is_sequential=True, + worker_id=0, + num_workers=1, + ) + assert msg.value == "Worker's data is too small" + with pytest.raises(AssertionError) as msg: + MyDataset( + data, + seq_len=1, + is_sequential=True, + worker_id=0, + num_workers=1, + ) + assert msg.value == "Worker's data is too small" + try: + MyDataset( + data, + is_sequential=False, + seq_len=9999, + worker_id=0, + num_workers=1, + ) + except AssertionError: + pytest.fail( + "Sequence length should not be validated for non-sequential dataset", + ) diff --git a/tests/unit/data/test_image.py b/tests/unit/data/test_image.py new file mode 100644 index 0000000..b1b178e --- /dev/null +++ b/tests/unit/data/test_image.py @@ -0,0 +1,110 @@ +import numpy as np +import PIL.Image +import pytest +import torch + +from mblm.data.utils.image import BinMode, ColorSpace, ImagePipeline + +IMAGE_PATH = "tests/fixtures/clevr/images/val/CLEVR_val_000000.png" + +IMAGE_SHAPE_NUMPY_RGB = (12, 15, 3) # H, W, C +IMAGE_SHAPE_TORCH_RGB = (3, 12, 15) # C, H, W +IMAGE_SHAPE_GRAY = (12, 15) # H, W + +TORCH_IMAGE_RGB = torch.randint(0, 256, IMAGE_SHAPE_TORCH_RGB, dtype=torch.uint8) +TORCH_IMAGE_GRAY = torch.randint(0, 256, IMAGE_SHAPE_GRAY, dtype=torch.uint8) + +NUMPY_IMAGE_RGB = np.random.randint(0, 256, IMAGE_SHAPE_NUMPY_RGB, dtype=np.uint8) +NUMPY_IMAGE_GRAY = np.random.randint(0, 256, IMAGE_SHAPE_GRAY, dtype=np.uint8) + + +class TestImagePipeline: + def test_from_path(self): + try: + ImagePipeline(IMAGE_PATH, ColorSpace.RGB) + except Exception: + pytest.fail("Could not create image pipeline from path") + + def test_from_pil(self): + try: + img = PIL.Image.new(ColorSpace.RGB.value, (5, 5)) + ImagePipeline(img, ColorSpace.GRAY) + except Exception: + pytest.fail("Could not create image pipeline from PIL") + + @pytest.mark.parametrize( + "image,cs", + ((TORCH_IMAGE_RGB, ColorSpace.RGB), (TORCH_IMAGE_GRAY, ColorSpace.GRAY)), + ) + def test_from_to_tensor(self, image: torch.Tensor, cs: ColorSpace): + assert image.equal(ImagePipeline(image, cs).to_tensor()) + + @pytest.mark.parametrize( + "image,cs", + ((NUMPY_IMAGE_RGB, ColorSpace.RGB), (NUMPY_IMAGE_GRAY, ColorSpace.GRAY)), + 
) + def test_from_to_numpy(self, image: np.ndarray, cs: ColorSpace): + assert np.array_equal(image, ImagePipeline(image, cs).to_numpy()) + + @pytest.mark.parametrize("image", (NUMPY_IMAGE_RGB.astype(np.long), TORCH_IMAGE_RGB.long())) + def test_from_wrong_dtype(self, image: torch.Tensor): + with pytest.raises(ValueError): + ImagePipeline(image, ColorSpace.RGB) + + @pytest.mark.parametrize( + "image,cs", + ((NUMPY_IMAGE_RGB, ColorSpace.RGB), (NUMPY_IMAGE_GRAY, ColorSpace.GRAY)), + ) + def test_to_grayscale(self, image: np.ndarray, cs: ColorSpace): + # for already gray images, nothing happens + assert ImagePipeline(image, cs).grayscale().to_numpy().shape == IMAGE_SHAPE_GRAY + + def test_resize(self): + new_w, new_h = 10, 15 + resized = ImagePipeline(NUMPY_IMAGE_RGB, ColorSpace.RGB).resize((new_w, new_h)).to_image() + assert (resized.width, resized.height) == (new_w, new_h) + + def test_crop(self): + crop_h_w_perc = 0.25, 0.20 + cropped = ImagePipeline(NUMPY_IMAGE_RGB, ColorSpace.RGB).crop((crop_h_w_perc)).to_image() + # adjust if image constants change + assert (cropped.height, cropped.width) == (6, 9) + + @pytest.mark.parametrize( + "image,cs", + ((NUMPY_IMAGE_RGB, ColorSpace.RGB), (NUMPY_IMAGE_GRAY, ColorSpace.GRAY)), + ) + def test_compress_jpg(self, image: np.ndarray, cs: ColorSpace): + buffer = ImagePipeline(image, cs).to_jpeg_buffer(5) + jpeg_magic_number = b"\xff\xd8\xff" + # make sure is a jpeg image now + assert buffer.to_buffer()[:3] == jpeg_magic_number + + @pytest.mark.parametrize( + "bin_mode,expected_bin_vals", + (("lower", (0, 127)), ("upper", (127, 255)), ("mean", (63, 191))), + ) + def test_downsample_two_bins(self, bin_mode: BinMode, expected_bin_vals: tuple[int, int]): + num_bins = 2 + output_image = ( + ImagePipeline(NUMPY_IMAGE_RGB, ColorSpace.RGB) + .downsample_channels( + num_bins, + bin_mode=bin_mode, + ) + .to_numpy() + ) + expected_values = np.array(expected_bin_vals, dtype=np.uint8) + received_values, _ = np.unique_counts(output_image) + assert np.array_equal(expected_values, received_values) + + @pytest.mark.parametrize("num_bins,abs_tol", ((254, 1), (255, 0), (256, 0))) + def test_downsample_boundaries(self, num_bins: int, abs_tol: int): + pipeline = ImagePipeline(NUMPY_IMAGE_RGB, ColorSpace.RGB) + original_image = pipeline.to_numpy() + output_image = pipeline.downsample_channels(num_bins).to_numpy() + assert np.isclose(original_image, output_image, atol=abs_tol, rtol=0).all() + + def test_downsample_error(self): + with pytest.raises(TypeError): + ImagePipeline(NUMPY_IMAGE_RGB, ColorSpace.RGB).downsample_channels(0) diff --git a/tests/unit/data/test_tokenizer.py b/tests/unit/data/test_tokenizer.py new file mode 100644 index 0000000..c83973f --- /dev/null +++ b/tests/unit/data/test_tokenizer.py @@ -0,0 +1,78 @@ +import pytest +import torch + +from mblm.data.utils import Tokenizer +from mblm.data.utils.tokenizer import TokenizerOptions + + +class TestTokenizer: + def test_pipeline(self): + pipeline = Tokenizer( + TokenizerOptions( + pad_token_id=10, + eom_token_id=11, + som_image_token_id=12, + som_text_token_id=13, + ) + ).pipeline + inp = torch.arange(0, 10, dtype=torch.uint8) # length 10 + out = pipeline(inp).with_eom().with_som_text().pad_right_to(15).to_long_tensor() + + assert out.dtype == torch.long + assert out.size(0) == 15 + assert out[0].item() == 13 # som token + assert out[11].item() == 11 # eom token + assert out[12:].equal(torch.tensor([10, 10, 10])) # padding token + + def test_pipeline_none(self): + pipeline = Tokenizer( + TokenizerOptions( + 
pad_token_id=10, + eom_token_id=None, + som_image_token_id=None, + som_text_token_id=None, + ) + ).pipeline + inp = torch.arange(0, 10, dtype=torch.uint8) # length 10 + + assert pipeline(inp).to_long_tensor().dtype == torch.long + assert pipeline(inp).pad_right_to(15).to_long_tensor().dtype == torch.long + + out = pipeline(inp).with_eom().with_som_text().pad_right_to(15).to_long_tensor() + + assert out.dtype == torch.long + assert out.size(0) == 15 + assert out[:10].equal(inp) # no change + assert out[10:].equal(torch.tensor([10, 10, 10, 10, 10])) # rest is padding + + def test_pipeline_pad_err(self): + pipeline = Tokenizer( + TokenizerOptions( + pad_token_id=10, + eom_token_id=None, + som_image_token_id=None, + som_text_token_id=None, + ) + ).pipeline + inp = torch.arange(0, 10) + + with pytest.raises(ValueError) as exc_info: + pipeline(inp).pad_right_to(9).to_long_tensor() + exp_error = "Tensor at dim 0 (length 10) larger than desired padded size 9" + assert exp_error in str(exc_info.value) + + def test_pipeline_2d_err(self): + pipeline = Tokenizer( + TokenizerOptions( + pad_token_id=10, + eom_token_id=None, + som_image_token_id=None, + som_text_token_id=None, + ) + ).pipeline + inp = torch.randn((2, 1)) + + with pytest.raises(ValueError) as exc_info: + pipeline(inp).to_long_tensor() + exp_error = "Can only process 1D tensors, input is 2D" + assert exp_error in str(exc_info.value) diff --git a/tests/unit/data/test_utils.py b/tests/unit/data/test_utils.py new file mode 100644 index 0000000..58507cf --- /dev/null +++ b/tests/unit/data/test_utils.py @@ -0,0 +1,42 @@ +import torch + +from mblm.data.utils import target_loss_mask +from mblm.data.utils.misc import shift_remap_tensor + + +def test_loss_mask(): + inp = torch.ones((3), dtype=torch.long) # special dtype + loss_mask = target_loss_mask([(inp, 2), (inp, 1), (inp, 0.1)]) + + assert len(loss_mask) == 3 * len(inp) + assert loss_mask.dtype == torch.float # not long + + mask_2, mask_1, mask_0_1 = torch.chunk(loss_mask, 3) + + assert (mask_2 == 2).all().item() + assert (mask_1 == 1).all().item() + assert (mask_0_1 == 0.1).all().item() + + +def test_shift_unshift(): + input_tensor = torch.tensor( + [ + [2, 3, 3], + [1, 2, 3], + [3, 5, 1], + ], + dtype=torch.uint8, + ) + shifted, unshift, indices = shift_remap_tensor(input_tensor, range_start=11) + expected_shifted = 11 + torch.tensor( + [ + [1, 2, 2], + [0, 1, 2], + [2, 3, 0], + ] + ) + + assert shifted.equal(expected_shifted) + assert shifted.shape == input_tensor.shape + assert shifted.dtype == input_tensor.dtype + assert input_tensor.equal(unshift[indices]) diff --git a/tests/unit/model/test_mblm.py b/tests/unit/model/test_mblm.py new file mode 100644 index 0000000..dffa3a6 --- /dev/null +++ b/tests/unit/model/test_mblm.py @@ -0,0 +1,63 @@ +import pytest +import torch + +from mblm import MBLM, MBLMModelConfig, MBLMReturnType +from mblm.model.transformer import TransformerBlockConfig + + +class TestMBLM: + num_tokens = 256 + 1 + pad_token_id = 256 + num_attn_heads = 16 + dim_attn_heads = 64 + ff_mult = 4 + dropout = 0 + use_rot_emb = True + use_flash_attn = False + model_fixtures_dims_lens: list[tuple[tuple[int, ...], tuple[int, ...]]] = [ + ((1024, 768, 512), (9, 7, 5)), + ((1024, 1024), (9, 7)), + ((1024,), (9,)), + ] + + @pytest.mark.parametrize("model_dims,seq_lens", model_fixtures_dims_lens) + def test_masked_loss( + self, + model_dims: tuple[int, ...], + seq_lens: tuple[int, ...], + ): + mmb = MBLM( + MBLMModelConfig( + num_tokens=self.num_tokens, + hidden_dims=model_dims, + 
seq_lens=seq_lens, + pad_token_id=self.pad_token_id, + num_layers=(1,) * len(model_dims), + train_checkpoint_chunks=None, + block=TransformerBlockConfig( + attn_head_dims=self.dim_attn_heads, + attn_num_heads=self.num_attn_heads, + attn_dropout=self.dropout, + ff_multiplier=self.ff_mult, + ff_dropout=self.dropout, + patch_pos_emb_type="fixed", + attn_use_rot_embs=self.use_rot_emb, + use_flash_attn=self.use_flash_attn, + ), + ) + ) + input_len = 9 + input_ids = torch.randint(0, self.num_tokens, size=(1, input_len), dtype=torch.long) + loss = mmb.forward(input_ids, return_type=MBLMReturnType.LOSS) + loss_with_identity_mask = mmb.forward( + input_ids, + loss_mask=torch.ones_like(input_ids), + return_type=MBLMReturnType.LOSS, + ) + assert torch.equal(loss, loss_with_identity_mask) + empty_loss = mmb.forward( + input_ids, + loss_mask=torch.zeros_like(input_ids), + return_type=MBLMReturnType.LOSS, + ) + assert empty_loss.item() == 0.0 diff --git a/tests/unit/model/test_megabyte_diff.py b/tests/unit/model/test_megabyte_diff.py new file mode 100644 index 0000000..c14cdaf --- /dev/null +++ b/tests/unit/model/test_megabyte_diff.py @@ -0,0 +1,128 @@ +import math + +import pytest +import torch +from MEGABYTE_pytorch import MEGABYTE + +from mblm import MBLM, MBLMModelConfig, MBLMReturnType +from mblm.model.transformer import TransformerBlockConfig +from mblm.utils.seed import seed_everything + + +class TestMegabyte: + num_tokens = 256 + 1 # for padding id + pad_token_id = 0 + num_attn_heads = 16 + dim_attn_heads = 64 + ff_mult = 4 + dropout = 0 + use_rot_emb = True + use_flash_attn = False + + model_fixtures_dims_lens: list[tuple[tuple[int, ...], tuple[int, ...]]] = [ + ((1024, 768, 512), (9, 7, 5)), + ((1024, 1024), (9, 7)), + ((1024,), (9,)), + ] + + def boostrap_models( + self, + model_dims: tuple[int, ...], + seq_lens: tuple[int, ...], + use_fixed_patch_use_encoding: bool = False, + ): + num_layers = (1,) * len(model_dims) + + seed_everything(8) + original_megabyte = MEGABYTE( + num_tokens=self.num_tokens, + pad_id=self.pad_token_id, + dim=model_dims, + max_seq_len=seq_lens, + depth=num_layers, + dim_head=self.dim_attn_heads, + heads=self.num_attn_heads, + attn_dropout=self.dropout, + ff_mult=self.ff_mult, + ff_dropout=self.dropout, + rel_pos=self.use_rot_emb, + pos_emb=use_fixed_patch_use_encoding, + flash_attn=self.use_flash_attn, + ) + # generating random numbers changes the stage of the random number + # generator + # - seed again + seed_everything(8) + patched_megabyte = MBLM( + MBLMModelConfig( + num_tokens=self.num_tokens, + hidden_dims=model_dims, + seq_lens=seq_lens, + pad_token_id=self.pad_token_id, + num_layers=num_layers, + train_checkpoint_chunks=None, + block=TransformerBlockConfig( + attn_head_dims=self.dim_attn_heads, + attn_num_heads=self.num_attn_heads, + attn_dropout=self.dropout, + ff_multiplier=self.ff_mult, + ff_dropout=self.dropout, + patch_pos_emb_type="fixed" if use_fixed_patch_use_encoding else None, + attn_use_rot_embs=self.use_rot_emb, + use_flash_attn=self.use_flash_attn, + ), + ) + ) + return original_megabyte, patched_megabyte + + def make_input_tensor(self, seq_len: int): + return torch.randint(1, self.num_tokens, size=(1, seq_len), dtype=torch.long) + + @pytest.mark.parametrize("model_dims,seq_lens", model_fixtures_dims_lens) + def test_megabyte_models_equal( + self, + model_dims: tuple[int, ...], + seq_lens: tuple[int, ...], + ): + original, patched = self.boostrap_models(model_dims, seq_lens) + + input_tensor = self.make_input_tensor(9) + loss_original = 
original.forward(input_tensor, return_loss=True) + loss_patched = patched.forward(input_tensor, return_type=MBLMReturnType.LOSS) + assert loss_original.isclose( + loss_patched, atol=0.0001 + ), f"losses ({loss_original}, {loss_patched}) do not match - model dim {model_dims}" + assert ( + loss_original.dtype == loss_patched.dtype + ), "loss dtypes do not match - model dim {model_dims}" + + loss_original.backward() + loss_patched.backward() + + logits_original = original.forward(input_tensor, return_loss=False) + logits_patched = patched.forward(input_tensor, return_type=MBLMReturnType.LOGITS) + assert logits_original.equal(logits_patched), f"preds do not match - model dim {model_dims}" + assert ( + logits_original.dtype == loss_patched.dtype + ), "pred dtypes do not match - model dim {model_dims}" + + @pytest.mark.parametrize("model_dims,seq_lens", model_fixtures_dims_lens) + def test_transformer_seq_len_overflow( + self, + model_dims: tuple[int, ...], + seq_lens: tuple[int, ...], + ): + original, patched = self.boostrap_models( + model_dims, seq_lens, use_fixed_patch_use_encoding=True + ) + input_seq_len = math.prod(seq_lens) # max sequence length possible + input_tensor = self.make_input_tensor(input_seq_len) + input_tensor_too_large = self.make_input_tensor(input_seq_len + 1) + + original.forward(input_tensor) # should pass + patched.forward(input_tensor) # should pass + + with pytest.raises(AssertionError): + original.forward(input_tensor_too_large) + with pytest.raises(AssertionError): + patched.forward(input_tensor_too_large) diff --git a/tests/unit/trainer/test_iter.py b/tests/unit/trainer/test_iter.py new file mode 100644 index 0000000..3e4296e --- /dev/null +++ b/tests/unit/trainer/test_iter.py @@ -0,0 +1,154 @@ +import random + +import pytest +from pytest_mock import MockerFixture + +from mblm.trainer.iter import epoch_cycler + + +class TestEpochCycler: + def test_callback(self, mocker: MockerFixture): + # kind of a verbose test overall but readable + data: list[int] = [-1] * 5 + + new_epoch_stub = mocker.stub("before_new_epoch") + + cycler = epoch_cycler( + data, + before_new_epoch=new_epoch_stub, + ) + + assert all([v == -1 for v in data]) + + # ------ epoch 0 + n = next(cycler) + assert n.epoch == 0 and n.batch == 0 + assert n.next_epoch == 0 and n.next_batch == 1 + new_epoch_stub.assert_called_with(0) + + n = next(cycler) + assert n.epoch == 0 and n.batch == 1 + + n = next(cycler) + assert n.epoch == 0 and n.batch == 2 + + n = next(cycler) + assert n.epoch == 0 and n.batch == 3 + + # last iteration of of epoch 0 + n = next(cycler) + assert n.epoch == 0 and n.batch == 4 + assert n.next_epoch == 1 and n.next_batch == 0 # peek into the future + + # ------ epoch 1 + n = next(cycler) + assert n.epoch == 1 and n.batch == 0 + assert n.next_epoch == 1 and n.next_batch == 1 + new_epoch_stub.assert_called_with(1) + + n = next(cycler) + assert n.epoch == 1 and n.batch == 1 + + n = next(cycler) + assert n.epoch == 1 and n.batch == 2 + + n = next(cycler) + assert n.epoch == 1 and n.batch == 3 + + n = next(cycler) + assert n.epoch == 1 and n.batch == 4 + assert n.next_epoch == 2 and n.next_batch == 0 # peek into the future + + # ------ epoch 2 + n = next(cycler) + assert n.epoch == 2 and n.batch == 0 + assert n.next_epoch == 2 and n.next_batch == 1 + new_epoch_stub.assert_called_with(2) + + assert new_epoch_stub.call_count == 3 + + @pytest.mark.parametrize("exp_start_epoch,exp_start_batch", [(1, 3), (2, 3), (0, 1)]) + def test_resume_middle(self, exp_start_epoch: int, exp_start_batch: int): 
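# Resuming mid-stream sketches how a trainer would restore the cycler after an
# interruption: persist (next_epoch, next_batch) from the last yield and later
# call, for example,
#   epoch_cycler(loader, start_epoch=saved_epoch, start_batch=saved_batch)
# (loader, saved_epoch and saved_batch are placeholder names). The first yield
# must then be data[exp_start_batch] of epoch exp_start_epoch, and the
# following epoch must start again at batch 0.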
+ epoch_len = 5 + + data = list(range(epoch_len)) + cycler = epoch_cycler( + data, + start_batch=exp_start_batch, + start_epoch=exp_start_epoch, + ) + + # a single epoch has length 5 + n = next(cycler) + epoch, batch, item = n.epoch, n.batch, n.item + + # make sure we start at an arbitrary epoch with a start index + assert epoch == exp_start_epoch + assert batch == exp_start_batch + assert item == data[exp_start_batch] + + # skip through + while epoch == exp_start_epoch: + n = next(cycler) + epoch, batch, item = n.epoch, n.batch, n.item + + # after an epoch has passed, resume from beginning of sequence + assert epoch == exp_start_epoch + 1 + assert batch == 0 + assert item == data[0] + + @pytest.mark.parametrize("max_iters", [0, 1, 2, 3, 4, 5, 6, 7]) + def test_max_iters(self, max_iters: int): + epoch_len = 5 + data = list(range(epoch_len)) + # regardless of the start, we aim for a number of target batches + start_epoch = random.choice(data) + start_batch = random.choice(data) + + cycler = epoch_cycler( + data, + max_iters=max_iters, + start_epoch=start_epoch, + start_batch=start_batch, + ) + num_iters = 0 + for _ in cycler: + num_iters += 1 + + assert num_iters == max_iters + + def test_changing_epoch_len(self): + epoch_len = 2 + data = list(range(epoch_len)) + cycler = epoch_cycler(data) + # ------ epoch 0 + # cycle to the end of epoch, 2 iterations + n = next(cycler) # 1 + n = next(cycler) # 2 + assert n.epoch == 0 and n.batch == 1 + assert n.next_epoch == 1 and n.next_batch == 0 + + # ------ epoch 1 + # mutate in place -> the next epoch should only have 1 iteration! + data.pop() + n = next(cycler) # should reach end of epoch + assert n.epoch == 1 and n.batch == 0 + assert n.next_epoch == 2 + + # ------ epoch 2 + # mutate in place -> the next epoch should have 3 iterations! 
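# (epoch_cycler re-reads len(seq) at the start of every epoch rather than
# caching it, so growing the list here is picked up on the next cycle)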
+ data.extend([1, 2]) + assert data == [0, 1, 2] + n = next(cycler) # 1 + n = next(cycler) # 2 + n = next(cycler) # 3 + assert n.epoch == 2 and n.batch == 2 + assert n.next_epoch == 3 and n.next_batch == 0 + + def test_resume_raises(self): + with pytest.raises(IndexError): + cycler = epoch_cycler( + seq=range(2), # 0, 1 + start_batch=2, + ) + next(cycler) diff --git a/tests/unit/utils/test_io.py b/tests/unit/utils/test_io.py new file mode 100644 index 0000000..6035609 --- /dev/null +++ b/tests/unit/utils/test_io.py @@ -0,0 +1,327 @@ +import csv +import tempfile +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from pathlib import Path +from typing import NamedTuple, cast + +import pytest +import torch +from pydantic import BaseModel + +from mblm import MBLM, MBLMModelConfig +from mblm.model.embeddings import MBLM_TOKEN_EMB_MIGRATION +from mblm.model.transformer import TransformerBlockConfig +from mblm.utils.io import ( + CSVWriter, + NDJSONWriter, + dump_yml, + load_model_state, + load_yml, + read_jsonl, + save_model_state, +) + +# TODO: Python 3.12, assert_type + + +class TestYMLUtils: + def test_cfrom_yml(self): + class Klass(BaseModel): + num: int + + with tempfile.TemporaryDirectory() as temp_dir: + kls = Klass(num=5) + dumped_to = dump_yml(Path(temp_dir) / "file", kls) + restored = load_yml(dumped_to, Klass) + # assert_type(restored, Klass) + assert isinstance(restored, Klass) + + +class DummyCSVEntry(NamedTuple): + kind: str + idx: int + time: str + + +class TestCSVWriter: + def test_parallel(self): + """Simulate parallel writes to the same file with index and timestamp.""" + + # Verify the file content + with tempfile.TemporaryDirectory() as tmpdir: + temp_output_dir = Path(tmpdir) + csv_writer = CSVWriter[DummyCSVEntry](output_dir=temp_output_dir, file_name="test") + + def write_row(index): + row = DummyCSVEntry( + kind="test", + idx=index, + time=datetime.now().isoformat(), + ) + csv_writer.write_row(row) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(write_row, i) for i in range(10)] + for future in futures: + future.result() + csv_file = temp_output_dir / "test.csv" + with csv_file.open("r", encoding="utf-8") as f: + reader = list(csv.reader(f)) + assert reader[0] == list(DummyCSVEntry._fields) + assert len(reader) == 11 + + indexes = [] + for row in reader[1:]: + assert row[0] == "test" + assert len(row[2]) > 0 + indexes.append(row[1]) + # Order may differ due to concurrent writing + assert list(map(str, range(10))) == sorted(indexes) + + +class TestModelCheckpointing: + class Model(torch.nn.Module): # noqa + def __init__(self, num_embs: int, emb_dim: int): + super().__init__() + self.emb = torch.nn.Embedding(num_embs, emb_dim) + self.seq = torch.nn.Sequential( + torch.nn.Embedding(num_embs, emb_dim), + torch.nn.Linear(emb_dim, num_embs), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _ = self.emb(x) + return self.seq(x) + + class OldModel(torch.nn.Module): # noqa + def __init__(self): + super().__init__() + self.keep = torch.nn.Parameter(torch.randn((1, 1))) + self.old_param = torch.nn.Parameter(torch.ones((1, 1))) + self.old_param_x = torch.nn.Parameter(torch.ones((1, 1))) + + class NewModel(torch.nn.Module): # noqa + def __init__(self): + super().__init__() + self.keep = torch.nn.Parameter(torch.randn((1, 1))) + self.new_param = torch.nn.Parameter(torch.zeros((1, 1))) + self.new_param_x = torch.nn.Parameter(torch.zeros((1, 1))) + + @pytest.mark.parametrize( + "src_emb_size,tgt_emb_size", + 
[ + (4, 4), # same size + (2, 3), # slightly larger + (4, 10), # much larger + ], + ) + @torch.no_grad() + def test_load_map_state(self, src_emb_size: int, tgt_emb_size: int): + assert src_emb_size <= tgt_emb_size, "Invalid test" + + model_src = self.Model(src_emb_size, 3) + model_tgt = self.Model(tgt_emb_size, 3) + with tempfile.TemporaryDirectory() as tmpdir: + _, chkpoint = save_model_state(tmpdir, "checkpoint", model_src, 0) + model_tgt, _ = load_model_state( + chkpoint, + model_tgt, + map_extend_embeddings={ + "emb.weight", + "seq.0.weight", + "seq.1.weight", + "seq.1.bias", + }, + ) + assert model_tgt.emb.weight.size(0) == tgt_emb_size + assert model_tgt.emb.weight[:src_emb_size].equal(model_src.emb.weight) + assert model_tgt.seq[1].weight[:src_emb_size].equal(model_src.seq[1].weight) + assert model_tgt.seq[1].bias[:src_emb_size].equal(model_src.seq[1].bias) + + # make sure that the first part of the logits is equal + max_token_id = src_emb_size - 1 + input_both = torch.tensor([max_token_id]).long() + src_logits = model_src.forward(input_both) + tgt_logits = model_tgt.forward(input_both) + + assert tgt_logits[:, :src_emb_size].equal(src_logits) + + @torch.no_grad() + def test_load_map_state_mmb(self): + def create_model(num_tokens: int): + return MBLM( + MBLMModelConfig( + num_tokens=num_tokens, + pad_token_id=0, + hidden_dims=(1024, 512), + num_layers=(1, 1), + seq_lens=(8192, 8), + train_checkpoint_chunks=None, + block=TransformerBlockConfig( + attn_head_dims=64, + attn_num_heads=8, + attn_use_rot_embs=True, + patch_pos_emb_type=None, + ), + ) + ) + + num_src_emb, num_tgt_emb = 5, 6 + model_src = create_model(num_src_emb) + model_tgt = create_model(num_tgt_emb) + with tempfile.TemporaryDirectory() as tmpdir: + _, chkpoint = save_model_state(tmpdir, "checkpoint", model_src, 0) + model_tgt, _ = load_model_state( + chkpoint, + model_tgt, + map_extend_embeddings=MBLM_TOKEN_EMB_MIGRATION, + ) + """ + This is the structure of the embeddings we're migrating: + + (token_embs_rev): ModuleList( + case 1: (0): Embedding(255, 512, padding_idx=0) + (1): Sequential( + case 2: (0): Embedding(255, 512, padding_idx=0) + (1): Rearrange('... r d -> ... (r d)') + (2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True) + (3): Linear(in_features=4096, out_features=1024, bias=True) + (4): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + ... 
+ + cases 3/4: (to_logits): Linear(in_features=512, out_features=255, bias=True) + ) + """ + src_emb = cast(torch.nn.Embedding, model_src.token_embs_rev[0]) + src_emb_seq = cast(torch.nn.Sequential, model_src.token_embs_rev[1]) + tgt_emb = cast(torch.nn.Embedding, model_tgt.token_embs_rev[0]) + tgt_emb_seq = cast(torch.nn.Sequential, model_tgt.token_embs_rev[1]) + + # case 1, base embedding + assert tgt_emb.num_embeddings == num_tgt_emb + assert tgt_emb.weight[:num_src_emb].equal(src_emb.weight) + # case 2, embedding in sequential + assert tgt_emb_seq[0].num_embeddings == num_tgt_emb + assert tgt_emb_seq[0].weight[:num_src_emb].equal(src_emb_seq[0].weight) + # case 3/4, logits + assert model_tgt.to_logits.weight.size(0) == num_tgt_emb + assert model_tgt.to_logits.bias.size(0) == num_tgt_emb + assert model_tgt.to_logits.weight[:num_src_emb].equal(model_src.to_logits.weight) + assert model_tgt.to_logits.bias[:num_src_emb].equal(model_src.to_logits.bias) + + # check if new token id works + max_new_token_id = num_tgt_emb - 1 + input_for_tgt_model_only = torch.tensor([[max_new_token_id]]).long() + with pytest.raises(Exception): + # should fail for old model + model_src.forward(input_for_tgt_model_only) + try: + # should work for migrated model + model_tgt.forward(input_for_tgt_model_only) + except Exception as error: + pytest.fail(f"Forward pass should work: {error}") + + @pytest.mark.parametrize("rename_from,rename_to", [("old_param", "new_param")]) + def test_load_map_state_rename(self, rename_from: str, rename_to: str): + new_model = self.NewModel() + old_model = self.OldModel() + assert not new_model.new_param.all() # before migration + assert old_model.old_param.all() + + with tempfile.TemporaryDirectory() as tmpdir: + _, chkpoint = save_model_state(tmpdir, "checkpoint", old_model, 0) + + new_model, _ = load_model_state( + chkpoint, + new_model, + map_rename_modules=((rename_from, rename_to),), + ) + + assert new_model.new_param.equal(old_model.old_param) + # make sure all modules with prefix are renamed + assert new_model.new_param_x.equal(old_model.old_param_x) + + @pytest.mark.parametrize( + "src_emb_size,tgt_emb_size,src_emb_dim,tgt_emb_dim,err_msg", + [ + (3, 2, 4, 4, "Mapping to a smaller number of embeddings"), # shrinking + (3, 3, 4, 5, "Mapping to a smaller embedding dimension"), # incompatible emb dim + ], + ) + def test_load_map_state_unsupported_mapping( + self, + src_emb_size: int, + tgt_emb_size: int, + src_emb_dim: int, + tgt_emb_dim: int, + err_msg: str, + ): + src_mod = self.Model(src_emb_size, src_emb_dim) + tgt_mod = self.Model(tgt_emb_size, tgt_emb_dim) + with tempfile.TemporaryDirectory() as tmpdir: + _, chkpoint = save_model_state(tmpdir, "checkpoint", src_mod, 0) + with pytest.raises(ValueError) as exc_info: + load_model_state( + chkpoint, + tgt_mod, + map_extend_embeddings={ + "emb.weight", + "seq.0.weight", + "seq.1.weight", + "seq.1.bias", + }, + ) + + assert err_msg in str(exc_info.value) + + def test_load_map_state_unsupported_different_modules(self): + src_mod = self.OldModel() + tgt_mod = self.NewModel() + with tempfile.TemporaryDirectory() as tmpdir: + _, chkpoint = save_model_state(tmpdir, "checkpoint", src_mod, 0) + with pytest.raises(ValueError) as exc_info: + load_model_state( + chkpoint, + tgt_mod, + map_extend_embeddings={"keep"}, + ) + + assert "Expected source and target state dict to match" in str(exc_info.value) + + +class TestNDJSONWriter: + class MyClass(BaseModel): # noqa: D106 + data: str + + first_entry = MyClass(data="a") + temp_entry = 
MyClass(data="bbbbbbb") # long entry + second_entry = MyClass(data="c") + + def test_write_and_remove(self): + with tempfile.TemporaryDirectory() as tmpdir: + file = Path(tmpdir) / "file.jsonl" + writer = NDJSONWriter[TestNDJSONWriter.MyClass](file) + writer.write_line(self.first_entry) + writer.write_line(self.temp_entry) + writer.remove_last_line() + writer.write_line(self.second_entry) + + result = read_jsonl(file, parse_lines_to=self.MyClass) + assert len(result) == 2 + assert result[0] == self.first_entry + assert result[1] == self.second_entry + + def test_write_and_remove_multiple(self): + with tempfile.TemporaryDirectory() as tmpdir: + file = Path(tmpdir) / "file.jsonl" + writer = NDJSONWriter[TestNDJSONWriter.MyClass](file) + writer.write_line(self.first_entry) + writer.write_line(self.temp_entry) + writer.remove_last_line() + writer.remove_last_line() + writer.remove_last_line() + + result = read_jsonl(file, parse_lines_to=self.MyClass) + assert len(result) == 0 diff --git a/tests/unit/utils/test_retry.py b/tests/unit/utils/test_retry.py new file mode 100644 index 0000000..41b77a2 --- /dev/null +++ b/tests/unit/utils/test_retry.py @@ -0,0 +1,53 @@ +from typing import Literal + +import pytest +from pytest_mock import MockerFixture + +from mblm.utils.retry import retry + + +class FailThenSuccess: + def __init__(self, num_fails: int): + self.max_num_fails = num_fails + self.has_failed_times = 0 + + def run(self): + if self.has_failed_times < self.max_num_fails: + self.has_failed_times += 1 + raise Exception("error") + return True + + +class TestTrainerUtils: + @pytest.mark.parametrize( + "n_retries,n_inner_fails,expected_calls,expected_result", + [ + [0, 0, 1, True], + [3, 0, 1, True], + [1, 0, 1, True], + [0, 1, 1, None], + [1, 1, 2, True], + [4, 1, 2, True], + [0, 2, 1, None], + [2, 2, 3, True], + ], + ) + def test_retry( + self, + mocker: MockerFixture, + n_retries: int, + n_inner_fails: int, + expected_calls: int, + expected_result: Literal[True] | None, + ): + try_func = FailThenSuccess(n_inner_fails) + try_func_spy = mocker.spy(try_func, "run") + on_error_stub = mocker.stub("on_error") + + retry_wrapper = retry(num_retries=n_retries, on_error=on_error_stub) + func = retry_wrapper(try_func.run) + result = func() + + assert result is expected_result + assert try_func_spy.call_count == expected_calls + assert on_error_stub.call_count == min(n_inner_fails, n_retries + 1) diff --git a/tests/unit/utils/test_top_n.py b/tests/unit/utils/test_top_n.py new file mode 100644 index 0000000..c389ee3 --- /dev/null +++ b/tests/unit/utils/test_top_n.py @@ -0,0 +1,98 @@ +from mblm.utils.top_n import TopN + + +class TestTopN: + def test_top_n_docstring_example(self): + # create the heap + top_n = TopN[str](2) + + # add items in random order + top_n.add((1, "a")) + top_n.add((3, "b")) + top_n.add((2, "c")) + + for idx, item in enumerate(top_n): + ... 
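# with a capacity of 2 and the default smallest-first ordering, only the two
# lowest (score, item) pairs survive, and iteration yields them in ascending
# order: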
+ # iteration 1: (1, "a") + # iteration 2: (2, "c") + + it = iter(top_n) + assert next(it) == (1, "a") + assert next(it) == (2, "c") + assert top_n.get_top(1) == [(1, "a")] + + def test_top_n_min(self): + num_items = 2 + top_n = TopN[str](num_items) + top_n.add((1.0, "1")) + top_n.add((3.0, "3")) + top_n.add((2.0, "2")) + assert len(top_n.get_top()) == num_items + assert len(top_n.get_top()) == len(top_n) + assert top_n.get_top(1) == [(1.0, "1")] + assert top_n.get_top(2) == [(1.0, "1"), (2.0, "2")] + + it = iter(top_n) + assert next(it) == (1.0, "1") + assert next(it) == (2.0, "2") + + # add more items, which should go in front + top_n.add((0, "0")) + assert len(top_n.get_top()) == num_items + it = iter(top_n) + assert next(it) == (0.0, "0") + assert next(it) == (1.0, "1") + + def test_top_n_max(self): + num_items = 2 + top_n = TopN[str](num_items, top_largest=True) + top_n.add((1.0, "1")) + top_n.add((3.0, "3")) + top_n.add((2.0, "2")) + assert len(top_n.get_top()) == num_items + assert len(top_n.get_top()) == len(top_n) + assert top_n.get_top(1) == [(3.0, "3")] + assert top_n.get_top(2) == [(3.0, "3"), (2.0, "2")] + + it = iter(top_n) + assert next(it) == (3.0, "3") + assert next(it) == (2.0, "2") + + # add more items, which should go to the back + top_n.add((4, "4")) + assert len(top_n.get_top()) == num_items + it = iter(top_n) + assert next(it) == (4.0, "4") + assert next(it) == (3.0, "3") + + def test_top_n_min_equal_stores_first(self): + num_items = 3 + top_n = TopN[str](num_items) + top_n.add((2.0, "2a")) + top_n.add((3.0, "3")) + top_n.add((1.0, "1")) + top_n.add((2.0, "2b")) + assert len(top_n.get_top()) == num_items + + it = iter(top_n) + assert next(it) == (1.0, "1") + assert next(it) == (2.0, "2a") + assert next(it) == (2.0, "2b") + + def test_top_n_min_deep_copy(self): + top_n_shallow = TopN[list[int]](1) + lst = [8] + top_n_shallow.add((1.0, lst)) + + # mutate a reference in place + lst[0] = 9 + assert top_n_shallow.get_top(1) == [(1.0, [9])] + assert top_n_shallow.get_top(1)[0][1] is lst + + top_n_deep = TopN[list[int]](1, deep_copy=True) + top_n_deep.add((1.0, lst)) + + # mutate a reference in place back to original + lst[0] = 8 + assert top_n_deep.get_top(1) == [(1.0, [9])] + assert top_n_deep.get_top(1)[0][1] is not lst diff --git a/uv.lock b/uv.lock index 8ab5057..9074be9 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,15 @@ resolution-markers = [ "python_full_version >= '3.13' and sys_platform == 'linux'", ] +[[package]] +name = "absl-py" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/8f/fc001b92ecc467cc32ab38398bd0bfb45df46e7523bf33c2ad22a505f06e/absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff", size = 118055 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", size = 133706 }, +] + [[package]] name = "altair" version = "5.4.1" @@ -324,6 +333,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/76/e6222113b83e3622caa4bb41032d0b1bf785250607392e1b778aca0b8a7d/charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc", size = 48543 }, ] +[[package]] +name = "click" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = 
"colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -683,6 +704,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271 }, ] +[[package]] +name = "joblib" +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, +] + [[package]] name = "json5" version = "0.9.25" @@ -1031,6 +1061,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-mock" }, { name = "python-dotenv" }, + { name = "rouge-score" }, { name = "ruff" }, { name = "types-pyyaml" }, { name = "types-tabulate" }, @@ -1062,6 +1093,7 @@ dev = [ { name = "pytest-cov", specifier = ">=5.0.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, + { name = "rouge-score", specifier = ">=0.1.2" }, { name = "ruff", specifier = ">=0.6.8" }, { name = "types-pyyaml", specifier = ">=6.0.12.20240917" }, { name = "types-tabulate", specifier = ">=0.9.0.20240106" }, @@ -1236,6 +1268,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/e9/5f72929373e1a0e8d142a130f3f97e6ff920070f87f91c4e13e40e0fba5a/networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2", size = 1702396 }, ] +[[package]] +name = "nltk" +version = "3.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -1693,16 +1740,15 @@ wheels = [ [[package]] name = "protobuf" -version = "5.28.2" +version = "3.20.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/a4/4579a61de526e19005ceeb93e478b61d77aa38c8a85ad958ff16a9906549/protobuf-5.28.2.tar.gz", hash = 
"sha256:59379674ff119717404f7454647913787034f03fe7049cbef1d74a97bb4593f0", size = 422494 } +sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/30/231764750e0987755b7b8d66771f161e5f002e165d27b72154c776dbabf7/protobuf-5.28.2-cp310-abi3-win32.whl", hash = "sha256:eeea10f3dc0ac7e6b4933d32db20662902b4ab81bf28df12218aa389e9c2102d", size = 419662 }, - { url = "https://files.pythonhosted.org/packages/7d/46/3fdf7462160135aee6a530f1ec66665b5b4132fa2e1002ab971bc6ec2589/protobuf-5.28.2-cp310-abi3-win_amd64.whl", hash = "sha256:2c69461a7fcc8e24be697624c09a839976d82ae75062b11a0972e41fd2cd9132", size = 431479 }, - { url = "https://files.pythonhosted.org/packages/37/45/d2a760580f8f2ed2825ba44cb370e0a4011ddef85e728f46ea3dd565a8a5/protobuf-5.28.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8b9403fc70764b08d2f593ce44f1d2920c5077bf7d311fefec999f8c40f78b7", size = 414736 }, - { url = "https://files.pythonhosted.org/packages/e6/23/ed718dc18e6a561445ece1e7a17d2dda0c634ad9cf663102b47f10005d8f/protobuf-5.28.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:35cfcb15f213449af7ff6198d6eb5f739c37d7e4f1c09b5d0641babf2cc0c68f", size = 316518 }, - { url = "https://files.pythonhosted.org/packages/23/08/a1ce0415a115c2b703bfa798f06f0e43ca91dbe29d6180bf86a9287b15e2/protobuf-5.28.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:5e8a95246d581eef20471b5d5ba010d55f66740942b95ba9b872d918c459452f", size = 316605 }, - { url = "https://files.pythonhosted.org/packages/9b/55/f24e3b801d2e108c48aa2b1b59bb791b5cffba89465cbbf66fc98de89270/protobuf-5.28.2-py3-none-any.whl", hash = "sha256:52235802093bd8a2811abbe8bf0ab9c5f54cca0a751fdd3f6ac2a21438bffece", size = 169566 }, + { url = "https://files.pythonhosted.org/packages/28/55/b80e8567ec327c060fa39b242392e25690c8899c489ecd7bb65b46b7bb55/protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99", size = 918427 }, + { url = "https://files.pythonhosted.org/packages/31/be/80a9c6f16dfa4d41be3edbe655349778ae30882407fa8275eb46b4d34854/protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e", size = 1051042 }, + { url = "https://files.pythonhosted.org/packages/db/96/948d3fcc1fa816e7ae1d27af59b9d8c5c5e582f3994fd14394f31da95b99/protobuf-3.20.3-cp310-cp310-win32.whl", hash = "sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c", size = 780167 }, + { url = "https://files.pythonhosted.org/packages/6f/5e/fc6feb366b0a9f28e0a2de3b062667c521cd9517d4ff55077b8f351ba2f3/protobuf-3.20.3-cp310-cp310-win_amd64.whl", hash = "sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7", size = 904029 }, + { url = "https://files.pythonhosted.org/packages/8d/14/619e24a4c70df2901e1f4dbc50a6291eb63a759172558df326347dce1f0d/protobuf-3.20.3-py2.py3-none-any.whl", hash = "sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db", size = 162128 }, ] [[package]] @@ -2115,6 +2161,75 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl", hash = 
"sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de", size = 26684 }, ] +[[package]] +name = "regex" +version = "2024.11.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/5f/bd69653fbfb76cf8604468d3b4ec4c403197144c7bfe0e6a5fc9e02a07cb/regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519", size = 399494 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/3c/4651f6b130c6842a8f3df82461a8950f923925db8b6961063e82744bddcc/regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91", size = 482674 }, + { url = "https://files.pythonhosted.org/packages/15/51/9f35d12da8434b489c7b7bffc205c474a0a9432a889457026e9bc06a297a/regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0", size = 287684 }, + { url = "https://files.pythonhosted.org/packages/bd/18/b731f5510d1b8fb63c6b6d3484bfa9a59b84cc578ac8b5172970e05ae07c/regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e", size = 284589 }, + { url = "https://files.pythonhosted.org/packages/78/a2/6dd36e16341ab95e4c6073426561b9bfdeb1a9c9b63ab1b579c2e96cb105/regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde", size = 782511 }, + { url = "https://files.pythonhosted.org/packages/1b/2b/323e72d5d2fd8de0d9baa443e1ed70363ed7e7b2fb526f5950c5cb99c364/regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e", size = 821149 }, + { url = "https://files.pythonhosted.org/packages/90/30/63373b9ea468fbef8a907fd273e5c329b8c9535fee36fc8dba5fecac475d/regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2", size = 809707 }, + { url = "https://files.pythonhosted.org/packages/f2/98/26d3830875b53071f1f0ae6d547f1d98e964dd29ad35cbf94439120bb67a/regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf", size = 781702 }, + { url = "https://files.pythonhosted.org/packages/87/55/eb2a068334274db86208ab9d5599ffa63631b9f0f67ed70ea7c82a69bbc8/regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c", size = 771976 }, + { url = "https://files.pythonhosted.org/packages/74/c0/be707bcfe98254d8f9d2cff55d216e946f4ea48ad2fd8cf1428f8c5332ba/regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86", size = 697397 }, + { url = "https://files.pythonhosted.org/packages/49/dc/bb45572ceb49e0f6509f7596e4ba7031f6819ecb26bc7610979af5a77f45/regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67", size = 768726 }, + { url = "https://files.pythonhosted.org/packages/5a/db/f43fd75dc4c0c2d96d0881967897926942e935d700863666f3c844a72ce6/regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash 
= "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d", size = 775098 }, + { url = "https://files.pythonhosted.org/packages/99/d7/f94154db29ab5a89d69ff893159b19ada89e76b915c1293e98603d39838c/regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2", size = 839325 }, + { url = "https://files.pythonhosted.org/packages/f7/17/3cbfab1f23356fbbf07708220ab438a7efa1e0f34195bf857433f79f1788/regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008", size = 843277 }, + { url = "https://files.pythonhosted.org/packages/7e/f2/48b393b51900456155de3ad001900f94298965e1cad1c772b87f9cfea011/regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62", size = 773197 }, + { url = "https://files.pythonhosted.org/packages/45/3f/ef9589aba93e084cd3f8471fded352826dcae8489b650d0b9b27bc5bba8a/regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e", size = 261714 }, + { url = "https://files.pythonhosted.org/packages/42/7e/5f1b92c8468290c465fd50c5318da64319133231415a8aa6ea5ab995a815/regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519", size = 274042 }, + { url = "https://files.pythonhosted.org/packages/58/58/7e4d9493a66c88a7da6d205768119f51af0f684fe7be7bac8328e217a52c/regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638", size = 482669 }, + { url = "https://files.pythonhosted.org/packages/34/4c/8f8e631fcdc2ff978609eaeef1d6994bf2f028b59d9ac67640ed051f1218/regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7", size = 287684 }, + { url = "https://files.pythonhosted.org/packages/c5/1b/f0e4d13e6adf866ce9b069e191f303a30ab1277e037037a365c3aad5cc9c/regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20", size = 284589 }, + { url = "https://files.pythonhosted.org/packages/25/4d/ab21047f446693887f25510887e6820b93f791992994f6498b0318904d4a/regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114", size = 792121 }, + { url = "https://files.pythonhosted.org/packages/45/ee/c867e15cd894985cb32b731d89576c41a4642a57850c162490ea34b78c3b/regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3", size = 831275 }, + { url = "https://files.pythonhosted.org/packages/b3/12/b0f480726cf1c60f6536fa5e1c95275a77624f3ac8fdccf79e6727499e28/regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f", size = 818257 }, + { url = "https://files.pythonhosted.org/packages/bf/ce/0d0e61429f603bac433910d99ef1a02ce45a8967ffbe3cbee48599e62d88/regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0", size = 792727 }, + { url = 
"https://files.pythonhosted.org/packages/e4/c1/243c83c53d4a419c1556f43777ccb552bccdf79d08fda3980e4e77dd9137/regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55", size = 780667 }, + { url = "https://files.pythonhosted.org/packages/c5/f4/75eb0dd4ce4b37f04928987f1d22547ddaf6c4bae697623c1b05da67a8aa/regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89", size = 776963 }, + { url = "https://files.pythonhosted.org/packages/16/5d/95c568574e630e141a69ff8a254c2f188b4398e813c40d49228c9bbd9875/regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d", size = 784700 }, + { url = "https://files.pythonhosted.org/packages/8e/b5/f8495c7917f15cc6fee1e7f395e324ec3e00ab3c665a7dc9d27562fd5290/regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34", size = 848592 }, + { url = "https://files.pythonhosted.org/packages/1c/80/6dd7118e8cb212c3c60b191b932dc57db93fb2e36fb9e0e92f72a5909af9/regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d", size = 852929 }, + { url = "https://files.pythonhosted.org/packages/11/9b/5a05d2040297d2d254baf95eeeb6df83554e5e1df03bc1a6687fc4ba1f66/regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45", size = 781213 }, + { url = "https://files.pythonhosted.org/packages/26/b7/b14e2440156ab39e0177506c08c18accaf2b8932e39fb092074de733d868/regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9", size = 261734 }, + { url = "https://files.pythonhosted.org/packages/80/32/763a6cc01d21fb3819227a1cc3f60fd251c13c37c27a73b8ff4315433a8e/regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60", size = 274052 }, + { url = "https://files.pythonhosted.org/packages/ba/30/9a87ce8336b172cc232a0db89a3af97929d06c11ceaa19d97d84fa90a8f8/regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a", size = 483781 }, + { url = "https://files.pythonhosted.org/packages/01/e8/00008ad4ff4be8b1844786ba6636035f7ef926db5686e4c0f98093612add/regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9", size = 288455 }, + { url = "https://files.pythonhosted.org/packages/60/85/cebcc0aff603ea0a201667b203f13ba75d9fc8668fab917ac5b2de3967bc/regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2", size = 284759 }, + { url = "https://files.pythonhosted.org/packages/94/2b/701a4b0585cb05472a4da28ee28fdfe155f3638f5e1ec92306d924e5faf0/regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4", size = 794976 }, + { url = "https://files.pythonhosted.org/packages/4b/bf/fa87e563bf5fee75db8915f7352e1887b1249126a1be4813837f5dbec965/regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577", size = 833077 }, + { url = "https://files.pythonhosted.org/packages/a1/56/7295e6bad94b047f4d0834e4779491b81216583c00c288252ef625c01d23/regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3", size = 823160 }, + { url = "https://files.pythonhosted.org/packages/fb/13/e3b075031a738c9598c51cfbc4c7879e26729c53aa9cca59211c44235314/regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e", size = 796896 }, + { url = "https://files.pythonhosted.org/packages/24/56/0b3f1b66d592be6efec23a795b37732682520b47c53da5a32c33ed7d84e3/regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe", size = 783997 }, + { url = "https://files.pythonhosted.org/packages/f9/a1/eb378dada8b91c0e4c5f08ffb56f25fcae47bf52ad18f9b2f33b83e6d498/regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e", size = 781725 }, + { url = "https://files.pythonhosted.org/packages/83/f2/033e7dec0cfd6dda93390089864732a3409246ffe8b042e9554afa9bff4e/regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29", size = 789481 }, + { url = "https://files.pythonhosted.org/packages/83/23/15d4552ea28990a74e7696780c438aadd73a20318c47e527b47a4a5a596d/regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39", size = 852896 }, + { url = "https://files.pythonhosted.org/packages/e3/39/ed4416bc90deedbfdada2568b2cb0bc1fdb98efe11f5378d9892b2a88f8f/regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51", size = 860138 }, + { url = "https://files.pythonhosted.org/packages/93/2d/dd56bb76bd8e95bbce684326302f287455b56242a4f9c61f1bc76e28360e/regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad", size = 787692 }, + { url = "https://files.pythonhosted.org/packages/0b/55/31877a249ab7a5156758246b9c59539abbeba22461b7d8adc9e8475ff73e/regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54", size = 262135 }, + { url = "https://files.pythonhosted.org/packages/38/ec/ad2d7de49a600cdb8dd78434a1aeffe28b9d6fc42eb36afab4a27ad23384/regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b", size = 273567 }, + { url = "https://files.pythonhosted.org/packages/90/73/bcb0e36614601016552fa9344544a3a2ae1809dc1401b100eab02e772e1f/regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84", size = 483525 }, + { url = "https://files.pythonhosted.org/packages/0f/3f/f1a082a46b31e25291d830b369b6b0c5576a6f7fb89d3053a354c24b8a83/regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4", size = 288324 }, + { url = 
"https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0", size = 284617 }, + { url = "https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0", size = 795023 }, + { url = "https://files.pythonhosted.org/packages/c4/7c/d4cd9c528502a3dedb5c13c146e7a7a539a3853dc20209c8e75d9ba9d1b2/regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7", size = 833072 }, + { url = "https://files.pythonhosted.org/packages/4f/db/46f563a08f969159c5a0f0e722260568425363bea43bb7ae370becb66a67/regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7", size = 823130 }, + { url = "https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c", size = 796857 }, + { url = "https://files.pythonhosted.org/packages/10/db/ac718a08fcee981554d2f7bb8402f1faa7e868c1345c16ab1ebec54b0d7b/regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3", size = 784006 }, + { url = "https://files.pythonhosted.org/packages/c2/41/7da3fe70216cea93144bf12da2b87367590bcf07db97604edeea55dac9ad/regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07", size = 781650 }, + { url = "https://files.pythonhosted.org/packages/a7/d5/880921ee4eec393a4752e6ab9f0fe28009435417c3102fc413f3fe81c4e5/regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e", size = 789545 }, + { url = "https://files.pythonhosted.org/packages/dc/96/53770115e507081122beca8899ab7f5ae28ae790bfcc82b5e38976df6a77/regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6", size = 853045 }, + { url = "https://files.pythonhosted.org/packages/31/d3/1372add5251cc2d44b451bd94f43b2ec78e15a6e82bff6a290ef9fd8f00a/regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4", size = 860182 }, + { url = "https://files.pythonhosted.org/packages/ed/e3/c446a64984ea9f69982ba1a69d4658d5014bc7a0ea468a07e1a1265db6e2/regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d", size = 787733 }, + { url = "https://files.pythonhosted.org/packages/2b/f1/e40c8373e3480e4f29f2692bd21b3e05f296d3afebc7e5dcf21b9756ca1c/regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff", size = 262122 }, + { url = "https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = 
"sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a", size = 273545 }, +] + [[package]] name = "requests" version = "2.32.3" @@ -2151,6 +2266,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9", size = 4242 }, ] +[[package]] +name = "rouge-score" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "nltk" }, + { name = "numpy" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04", size = 17400 } + [[package]] name = "rpds-py" version = "0.20.0"