diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..2da60a5 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +@shamanez @Jacobsolawetz @ben-epstein @SachiraKuruppu @metric-space \ No newline at end of file diff --git a/.requirements.txt b/.requirements.txt deleted file mode 100644 index 1f92cbb..0000000 --- a/.requirements.txt +++ /dev/null @@ -1,201 +0,0 @@ -absl-py==1.4.0 -accelerate==0.22.0 -aiohttp==3.8.5 -aiosignal==1.3.1 -alembic==1.12.0 -appdirs==1.4.4 -async-timeout==4.0.3 -attrs==19.3.0 -Automat==0.8.0 -awscli==1.27.101 -bitsandbytes==0.41.1 -blinker==1.6.2 -boto3==1.28.36 -botocore==1.31.36 -build==0.10.0 -CacheControl==0.13.1 -cachetools==5.3.1 -certifi==2019.11.28 -chardet==3.0.4 -charset-normalizer==3.2.0 -cleo==2.0.1 -cloud-init==23.2.2 -cloudpickle==2.2.1 -cmake==3.27.2 -colorama==0.4.3 -command-not-found==0.3 -configobj==5.0.6 -constantly==15.1.0 -contourpy==1.1.0 -crashtest==0.4.1 -cryptography==2.8 -cycler==0.11.0 -databricks-cli==0.17.7 -datasets==2.14.4 -dbus-python==1.2.16 -devscripts===2.20.2ubuntu2 -dill==0.3.7 -distlib==0.3.7 -distro==1.4.0 -distro-info==0.23+ubuntu1.1 -docker==6.1.3 -docker-pycreds==0.4.0 -docutils==0.16 -dulwich==0.21.5 -ec2-hibinit-agent==1.0.0 -einops==0.6.1 -entrypoints==0.3 -filelock==3.12.2 -Flask==2.3.3 -fonttools==4.42.1 -frozenlist==1.4.0 -fsspec==2023.6.0 -gitdb==4.0.10 -GitPython==3.1.32 -google-auth==2.22.0 -google-auth-oauthlib==1.0.0 -gpg==1.13.1 -greenlet==2.0.2 -grpcio==1.57.0 -gunicorn==21.2.0 -hibagent==1.0.1 -httplib2==0.14.0 -huggingface-hub==0.16.4 -hyperlink==19.0.0 -idna==2.8 -importlib-metadata==6.8.0 -importlib-resources==6.0.1 -incremental==16.10.1 -installer==0.7.0 -itsdangerous==2.1.2 -jaraco.classes==3.3.0 -jeepney==0.8.0 -Jinja2==3.1.2 -jmespath==1.0.1 -joblib==1.3.2 -jsonpatch==1.22 -jsonpointer==2.0 -jsonschema==4.17.3 -keyring==24.2.0 -kiwisolver==1.4.5 -language-selector==0.1 -launchpadlib==1.10.13 -lazr.restfulclient==0.14.2 -lazr.uri==1.0.3 -lit==16.0.6 -Mako==1.2.4 -Markdown==3.4.4 -MarkupSafe==2.1.3 -matplotlib==3.7.2 -mlflow==2.6.0 -more-itertools==4.2.0 -mpmath==1.3.0 -msgpack==1.0.5 -multidict==6.0.4 -multiprocess==0.70.15 -netifaces==0.10.4 -networkx==3.1 -numpy==1.24.4 -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 -oauthlib==3.1.0 -packaging==23.1 -pandas==2.0.3 -pathtools==0.1.2 -pexpect==4.8.0 -Pillow==10.0.0 -pkginfo==1.9.6 -pkgutil_resolve_name==1.3.10 -platformdirs==3.10.0 -poetry-core==1.7.0 -poetry-plugin-export==1.5.0 -protobuf==4.24.0 -psutil==5.9.5 -ptyprocess==0.7.0 -pyarrow==12.0.1 -pyasn1==0.4.2 -pyasn1-modules==0.2.1 -PyGObject==3.36.0 -PyHamcrest==1.9.0 -PyJWT==1.7.1 -pymacaroons==0.13.0 -PyNaCl==1.3.0 -pyOpenSSL==19.0.0 -pyparsing==3.0.9 -pyproject_hooks==1.0.0 -pyrsistent==0.15.5 -pyserial==3.4 -python-apt==2.0.1+ubuntu0.20.4.1 -python-dateutil==2.8.2 -python-debian==0.1.36+ubuntu1.1 -python-magic==0.4.16 -pytz==2023.3 -pyxdg==0.26 -PyYAML==5.3.1 -querystring-parser==1.2.4 -rapidfuzz==2.15.1 -regex==2023.8.8 -requests==2.31.0 -requests-oauthlib==1.3.1 -requests-toolbelt==1.0.0 -requests-unixsocket==0.2.0 -rsa==4.7.2 -s3transfer==0.6.0 -safetensors==0.3.3 -scikit-learn==1.3.0 -scipy==1.10.1 -SecretStorage==3.3.3 -sentry-sdk==1.29.2 -service-identity==18.1.0 
-setproctitle==1.3.2 -shellingham==1.5.3 -simplejson==3.16.0 -six==1.14.0 -smmap==5.0.0 -sos==4.5.6 -SQLAlchemy==2.0.20 -sqlparse==0.4.4 -ssh-import-id==5.10 -sympy==1.12 -systemd-python==234 -tabulate==0.9.0 -tensorboard==2.14.0 -tensorboard-data-server==0.7.1 -threadpoolctl==3.2.0 -tokenizers==0.13.3 -tomli==2.0.1 -tomlkit==0.12.1 -torch==2.0.1 -torch-tb-profiler==0.4.1 -tqdm==4.66.1 -transformers==4.32.0 -triton==2.0.0 --e git+https://github.com/huggingface/trl.git@34e6948d459540a21f80c5be227fb4da039dd97a#egg=trl -trove-classifiers==2023.8.7 -Twisted==18.9.0 -typing_extensions==4.7.1 -tzdata==2023.3 -ubuntu-advantage-tools==8001 -ufw==0.36 -unattended-upgrades==0.1 -unidiff==0.5.5 -urllib3==1.26.16 -virtualenv==20.24.3 -wadllib==1.3.3 -wandb==0.15.8 -websocket-client==1.6.2 -Werkzeug==2.3.7 -xxhash==3.3.0 -yarl==1.9.2 -zipp==3.16.2 -zope.interface==4.7.1 -sklearn==0.0.post9 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9719e95 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,25 @@ +# Contributing to DALM + +Thanks for helping out! We're excited for your issues and PRs + +## Building from local + +Building the repo is straightforward. Clone the repo, and install the package. We use [invoke](https://github.com/pyinvoke/invoke) to manage `DALM` +```shell +git clone https://github.com/arcee-ai/DALM.git && cd DALM +pip install invoke +inv install +``` +This will install the repo, with its dev dependencies, in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) (for live updates on code changes) + +## Format, lint, test +Because we use `invoke`, the following is all you need to prepare for a pr +```shell +inv format # black, ruff +inv lint # black check, ruff check, mypy +inv test # pytest +``` + +We require 95% test coverage for all PRs. + +For more information around our `invoke` commands, see [`tasks.py`](https://github.com/arcee-ai/DALM/blob/main/tasks.py) and our [`pyproject.toml`](https://github.com/arcee-ai/DALM/blob/main/pyproject.toml) configuration \ No newline at end of file diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE diff --git a/README.md b/README.md index a9fb544..1d10df2 100644 --- a/README.md +++ b/README.md @@ -26,30 +26,43 @@ For the first time in the literature, we modified the initial RAG-end2end model - Additionally, we have data processing codes and synthetic data generation code inside the [datasets](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets) folder. -## Code execution -To perform training and evaluation for both the retriever model and the new rag-e2e model, please adhere to the following steps. +# Usage +To perform training and evaluation for both the retriever model and the new rag-e2e model, please adhere to the following steps. -- The setup for training and evaluation can be effortlessly executed provided you possess a [CSV](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets/toy_dataset_train.py) file containing three columns: passage, query, and answer. You can utilize the script [question_answer_generation.py](https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py) to generate this CSV. -- It's important to highlight that the retriever-only training method employs solely the passages and queries, whereas the rag-e2e training code utilizes all three columns. -- In our experiments, we utilize BAAI/bge-large-en as the retriever and employ meta-llama/Llama-2-7b-hf as the generator. 
It's important to note that this code is designed to be compatible with any embedding model or autoregressive model available in the Hugging Face model repository at https://huggingface.co/models. +## Installation + +You can install this repo directly via `pip install indomain` -## Clone the repositary -- `git clone https://github.com/arcee-ai/DALM.git` -- `cd DALM` +Alternatively, for development or research, you can clone and install the repo locally: +```shell +git clone https://github.com/arcee-ai/DALM.git && cd DALM +pip install --upgrade -e . +``` +This will install the DALM repo and all necessary dependencies. -## Install the necesarry libraries -Create your desired virtual environment isntall all necasary librries. -- `pip install -r requirements.txt` +Make sure things are installed correctly by running `dalm version` + +## Data setup +### tl;dr +You can run `dalm qa-gen <path-to-dataset>` to preprocess your dataset for training. See `dalm qa-gen --help` for more options. +
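The same preprocessing can also be driven from Python (a minimal sketch: the file paths are placeholders, the column names and sizes mirror the CLI defaults, and, like the CLI path, the QA model is loaded in 8-bit, so a GPU is expected):
+```python
+from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_disk
+
+# Writes question_answer_pairs_train(.csv) and question_answer_pairs_test(.csv) into output_dir
+generate_qa_from_disk(
+    "path/to/my_passages.csv",  # placeholder: a CSV with passage and title columns
+    passage_column_name="Abstract",
+    title_column_name="Title",
+    sample_size=1000,
+    batch_size=100,
+    output_dir="./qa_output",
+    as_csv=True,
+)
+```
+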
If you do not have a dataset, you can start with ours +```shell +dalm qa-gen dalm/datasets/toy_data_train.csv +``` +- The setup for training and evaluation can be effortlessly executed provided you possess a [CSV](https://github.com/arcee-ai/DALM/tree/main/dalm/datasets/toy_data_train.csv) file containing two/three columns: `Passage`, `Query` (and `Answer` if running e2e). You can utilize the script [question_answer_generation.py](https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py) to generate this CSV. +- It's important to highlight that the retriever-only training method employs solely the passages and queries, whereas the rag-e2e training code utilizes all three columns. +- In our experiments, we utilize `BAAI/bge-large-en` as the default retriever and employ `meta-llama/Llama-2-7b-hf` as the default generator. The code is designed to be compatible with any embedding model or autoregressive model available in the Hugging Face model repository at https://huggingface.co/models. ## Training +You can leverage our scripts directly if you'd like, or you can use the `dalm` cli. The arguments for both are identical + ### Train Retriever Only Train `BAAI/bge-large-en` retriever with contrastive learning. - -``` -python dalm/training/retriever_only/train_retriever_only.py ---train_dataset_csv_path ./dalm/datasets/toy_data_train.csv" \ +```shell +python dalm/training/retriever_only/train_retriever_only.py \ +--train_dataset_csv_path "./dalm/datasets/toy_data_train.csv" \ --model_name_or_path "BAAI/bge-large-en" \ --output_dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \ --use_peft \ @@ -57,12 +70,22 @@ python dalm/training/retriever_only/train_retriever_only.py --report_to all \ --per_device_train_batch_size 150 ``` +or +```shell +dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \ +--output-dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \ +--use-peft \ +--with-tracking \ +--report-to all \ +--per-device-train-batch-size 150 +``` -### Train Retriever and Generator Jointly (RAG-e2e) +For all available arguments and options, see `dalm train-retriever-only --help` +### Train Retriever and Generator Jointly (RAG-e2e) Train `Llama-2-7b` generator jointly with the retriever model `BAAI/bge-large-en`. 
-``` +```shell python dalm/training/rag_e2e/train_rage2e.py \ --dataset_path "./dalm/datasets/toy_data_train.csv" \ --retriever_name_or_path "BAAI/bge-large-en" \ @@ -72,6 +95,20 @@ python dalm/training/rag_e2e/train_rage2e.py \ --report_to all \ --per_device_train_batch_size 24 ``` +or +```shell +dalm train-rag-e2e \ +"./dalm/datasets/toy_data_train.csv" \ +"BAAI/bge-large-en" \ +"meta-llama/Llama-2-7b-hf" \ +--output-dir "./dalm/training/rag_e2e/rag_e2e_checkpoints" \ +--with-tracking \ +--report-to all \ +--per-device-train-batch-size 24 +``` + +For all available arguments and options, see `dalm train-rag-e2e --help` + ## Evaluation Here's a summary of evaluation results on evaluating on a 200K line test csv of Patent abstracts @@ -86,7 +123,7 @@ To run retriever only eval (make sure you have the checkpoints in the project root) ```bash - python dalm/eval/eval_retriever_only.py --dataset_path qa_paits_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints + python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints ``` For the e2e eval @@ -94,3 +131,7 @@ For the e2e eval ```bash python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints ``` + + +## Contributing +See [CONTRIBUTING](https://github.com/arcee-ai/DALM/tree/main/CONTRIBUTING.md) diff --git a/dalm/__init__.py b/dalm/__init__.py index 6c8e6b9..f102a9c 100644 --- a/dalm/__init__.py +++ b/dalm/__init__.py @@ -1 +1 @@ -__version__ = "0.0.0" +__version__ = "0.0.1" diff --git a/dalm/cli.py b/dalm/cli.py new file mode 100644 index 0000000..300acff --- /dev/null +++ b/dalm/cli.py @@ -0,0 +1,308 @@ +from enum import Enum +from pathlib import Path +from typing import Optional + +import typer +from transformers import SchedulerType +from typing_extensions import Annotated + +from dalm import __version__ +from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_disk +from dalm.training.rag_e2e.train_rage2e import train_e2e +from dalm.training.retriever_only.train_retriever_only import train_retriever + +cli = typer.Typer() +HERE = Path(__file__).parent + + +class DALMSchedulerType(Enum): + LINEAR = SchedulerType.LINEAR + COSINE = SchedulerType.COSINE + COSINE_WITH_RESTARTS = SchedulerType.COSINE_WITH_RESTARTS + POLYNOMIAL = SchedulerType.POLYNOMIAL + CONSTANT = SchedulerType.CONSTANT + CONSTANT_WITH_WARMUP = SchedulerType.CONSTANT_WITH_WARMUP + + +@cli.command() +def version() -> None: + """Print the current version of DALM""" + print(f"🐾You are running DALM version: {__version__}") + + +@cli.command() +def train_rag_e2e( + dataset_path: Annotated[ + str, + typer.Argument( + help="Path to the dataset to train with. 
Can be a huggingface dataset directory or a csv file.", + show_default=False, + ), + ], + retriever_name_or_path: Annotated[ + str, + typer.Argument( + help="Path to pretrained retriever or identifier from huggingface.co/models.", show_default=False + ), + ], + generator_name_or_path: Annotated[ + str, + typer.Argument( + help="Path to pretrained (causal) generator or identifier from huggingface.co/models.", show_default=False + ), + ], + dataset_passage_col_name: Annotated[ + str, typer.Option(help="Name of the column containing the passage") + ] = "Abstract", + dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", + dataset_answer_col_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer", + query_max_len: Annotated[ + int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated") + ] = 50, + passage_max_len: Annotated[ + int, typer.Option(help="The max passage sequence length during tokenization. Longer sequences are truncated") + ] = 128, + generator_max_len: Annotated[ + int, + typer.Option( + help="The max generator input sequence length during tokenization. Longer sequences are truncated" + ), + ] = 256, + per_device_train_batch_size: Annotated[ + int, typer.Option(help="Batch size (per device) for the training dataloader.") + ] = 32, + learning_rate: Annotated[ + float, typer.Option(help="Initial learning rate (after the potential warmup period) to use.") + ] = 1e-4, + logit_scale: Annotated[int, typer.Option(help="Logit scale for the contrastive learning.")] = 100, + weight_decay: Annotated[float, typer.Option(help="Weight decay to use.")] = 0.0, + num_train_epochs: Annotated[int, typer.Option(help="Total number of training epochs to perform.")] = 1, + max_train_steps: Annotated[ + Optional[int], + typer.Option(help="Total number of training steps to perform. If provided, overrides num_train_epochs."), + ] = None, + gradient_accumulation_steps: Annotated[ + int, typer.Option(help="Number of updates steps to accumulate before performing a backward/update pass.") + ] = 1, + lr_scheduler_type: Annotated[ + DALMSchedulerType, typer.Option(help="The scheduler type to use.") + ] = DALMSchedulerType.LINEAR.value, + num_warmup_steps: Annotated[int, typer.Option(help="Number of steps for the warmup in the lr scheduler.")] = 100, + output_dir: Annotated[Optional[str], typer.Option(help="Where to store the final model.")] = None, + seed: Annotated[int, typer.Option(help="A seed for reproducible training.")] = 42, + hub_model_id: Annotated[ + Optional[str], + typer.Option( + help="[NOT CURRENTLY USED]. The name of the repository to keep in sync with the local `output_dir`." + ), + ] = None, + hub_token: Annotated[ + Optional[str], typer.Option(help="[NOT CURRENTLY USED]. The token to use to push to the Model Hub.") + ] = None, + checkpointing_steps: Annotated[ + Optional[str], + typer.Option( + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." + ), + ] = None, + resume_from_checkpoint: Annotated[ + Optional[str], typer.Option(help="If the training should continue from a checkpoint folder.") + ] = None, + with_tracking: Annotated[bool, typer.Option(help="Whether to enable experiment trackers for logging.")] = True, + report_to: Annotated[ + str, + typer.Option( + help=( + 'The integration to report the results and logs to. 
Supported platforms are `"tensorboard"`, ' + '`"wandb"`, `"mlflow"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all.' + "integrations. Only applicable when `--with_tracking` is passed." + ) + ), + ] = "all", + sanity_test: Annotated[ + bool, typer.Option(help="[NOT CURRENTLY USED]. Whether to sanity test the model after training") + ] = True, + use_peft: Annotated[bool, typer.Option(help="Whether to use Peft during fine-tuning.")] = True, +) -> None: + """End-to-end train an in-domain model, including the retreiver and generator""" + train_e2e( + dataset_or_path=dataset_path, + retriever_name_or_path=retriever_name_or_path, + generator_name_or_path=generator_name_or_path, + dataset_passage_col_name=dataset_passage_col_name, + dataset_query_col_name=dataset_query_col_name, + dataset_answer_col_name=dataset_answer_col_name, + query_max_len=query_max_len, + passage_max_len=passage_max_len, + generator_max_len=generator_max_len, + per_device_train_batch_size=per_device_train_batch_size, + learning_rate=learning_rate, + logit_scale=logit_scale, + weight_decay=weight_decay, + num_train_epochs=num_train_epochs, + max_train_steps=max_train_steps, + gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type.value, + num_warmup_steps=num_warmup_steps, + output_dir=output_dir, + seed=seed, + hub_model_id=hub_model_id, + hub_token=hub_token, + checkpointing_steps=checkpointing_steps, + resume_from_checkpoint=resume_from_checkpoint, + with_tracking=with_tracking, + report_to=report_to, + sanity_test=sanity_test, + use_peft=use_peft, + ) + + +@cli.command() +def train_retriever_only( + model_name_or_path: Annotated[ + str, typer.Argument(help="Path to the model or identifier from huggingface.co/models.", show_default=False) + ], + train_dataset_csv_path: Annotated[ + str, + typer.Argument( + help="Path to the train dataset to train with. Can be a huggingface dataset directory or a csv file.", + show_default=False, + ), + ], + test_dataset_csv_path: Annotated[ + Optional[str], + typer.Option( + help="Optional path to the test dataset for training. Can be a huggingface dataset directory or a csv file." + ), + ] = None, + dataset_passage_col_name: Annotated[ + str, typer.Option(help="Name of the column containing the passage") + ] = "Abstract", + dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", + query_max_len: Annotated[ + int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated") + ] = 50, + passage_max_len: Annotated[ + int, typer.Option(help="The max passage sequence length during tokenization. Longer sequences are truncated") + ] = 128, + per_device_train_batch_size: Annotated[ + int, typer.Option(help="Batch size (per device) for the training dataloader.") + ] = 32, + learning_rate: Annotated[ + float, typer.Option(help="Initial learning rate (after the potential warmup period) to use.") + ] = 1e-4, + logit_scale: Annotated[int, typer.Option(help="Logit scale for the contrastive learning.")] = 100, + weight_decay: Annotated[float, typer.Option(help="Weight decay to use.")] = 0.0, + num_train_epochs: Annotated[int, typer.Option(help="Total number of training epochs to perform.")] = 3, + max_train_steps: Annotated[ + Optional[int], + typer.Option(help="Total number of training steps to perform. 
If provided, overrides num_train_epochs."), + ] = None, + gradient_accumulation_steps: Annotated[ + int, typer.Option(help="Number of updates steps to accumulate before performing a backward/update pass.") + ] = 1, + lr_scheduler_type: Annotated[ + DALMSchedulerType, typer.Option(help="The scheduler type to use.") + ] = DALMSchedulerType.LINEAR.value, + num_warmup_steps: Annotated[int, typer.Option(help="Number of steps for the warmup in the lr scheduler.")] = 0, + output_dir: Annotated[Optional[str], typer.Option(help="Where to store the final model.")] = None, + seed: Annotated[int, typer.Option(help="A seed for reproducible training.")] = 42, + hub_model_id: Annotated[ + Optional[str], + typer.Option( + help="[NOT CURRENTLY USED]. The name of the repository to keep in sync with the local `output_dir`." + ), + ] = None, + hub_token: Annotated[ + Optional[str], typer.Option(help="[NOT CURRENTLY USED]. The token to use to push to the Model Hub.") + ] = None, + checkpointing_steps: Annotated[ + Optional[str], + typer.Option( + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." + ), + ] = None, + resume_from_checkpoint: Annotated[ + Optional[str], typer.Option(help="If the training should continue from a checkpoint folder.") + ] = None, + with_tracking: Annotated[bool, typer.Option(help="Whether to enable experiment trackers for logging.")] = True, + report_to: Annotated[ + str, + typer.Option( + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`, ' + '`"wandb"`, `"mlflow"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all ' + "integrations. Only applicable when `--with_tracking` is passed." + ) + ), + ] = "all", + sanity_test: Annotated[ + bool, typer.Option(help="[NOT CURRENTLY USED]. Whether to sanity test the model after training") + ] = True, + use_peft: Annotated[bool, typer.Option(help="Whether to use Peft during fine-tuning.")] = True, +) -> None: + """Train only the in-domain retriever, using contrastive learning""" + train_retriever( + train_dataset_or_csv_path=train_dataset_csv_path, + test_dataset_or_csv_path=test_dataset_csv_path, + model_name_or_path=model_name_or_path, + dataset_passage_col_name=dataset_passage_col_name, + dataset_query_col_name=dataset_query_col_name, + query_max_len=query_max_len, + passage_max_len=passage_max_len, + per_device_train_batch_size=per_device_train_batch_size, + learning_rate=learning_rate, + logit_scale=logit_scale, + weight_decay=weight_decay, + num_train_epochs=num_train_epochs, + max_train_steps=max_train_steps, + gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type.value, + num_warmup_steps=num_warmup_steps, + output_dir=output_dir, + seed=seed, + hub_model_id=hub_model_id, + hub_token=hub_token, + checkpointing_steps=checkpointing_steps, + resume_from_checkpoint=resume_from_checkpoint, + with_tracking=with_tracking, + report_to=report_to, + sanity_test=sanity_test, + use_peft=use_peft, + ) + + +@cli.command() +def qa_gen( + dataset_path: Annotated[ + str, + typer.Argument( + help="Path to the input dataset.
Can be huggingface dataset directory, " + "path to a dataset on hub, or a csv file.", + show_default=False, + ), + ], + output_dir: Annotated[str, typer.Option(help="Output directory to store the resulting files")] = str(HERE), + passage_column_name: Annotated[str, typer.Option(help="Column name for the passage/text")] = "Abstract", + title_column_name: Annotated[str, typer.Option(help="Column name for the title of the full document")] = "Title", + batch_size: Annotated[ + int, typer.Option(help="Batch size (per device) for generating question answer pairs.") + ] = 100, + sample_size: Annotated[ + int, typer.Option(help="Number of examples to process. If the data has more samples, they will be dropped") + ] = 1000, + as_csv: Annotated[ + bool, + typer.Option( + help="Save the files as CSV. If False, will save them as a dataset directory via [`~Dataset.save_to_disk`]" + ), + ] = True, +) -> None: + """Generate question-answer pairs for a given input dataset""" + generate_qa_from_disk( + dataset_path, passage_column_name, title_column_name, sample_size, batch_size, output_dir, as_csv + ) + + +if __name__ == "__main__": + cli() diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py index 6e9f8eb..b64928a 100644 --- a/dalm/datasets/qa_gen/question_answer_generation.py +++ b/dalm/datasets/qa_gen/question_answer_generation.py @@ -1,64 +1,75 @@ import argparse +import os.path +import warnings +from functools import partial +from pathlib import Path import datasets import torch +from datasets import Dataset, DatasetDict from sklearn.model_selection import train_test_split -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer -device = "cuda:0" if torch.cuda.is_available() else "cpu" +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" TEST_SIZE = 0.2 +QA_MODEL = "potsawee/t5-large-generation-squad-QuestionAnswer" -parser = argparse.ArgumentParser(description="Generate question answer pairs from the dataset of passages") -parser.add_argument( - "--dataset_path", - type=str, - default=None, - help="dataset path in the local dir. Can be huggingface dataset directory or a csv file.", - required=True, -) -parser.add_argument( - "--title_column_name", - type=str, - default="Title", - help="This title is used to identify passages from the same text", -) -parser.add_argument( - "--passage_column_name", - type=str, - default="Abstract", - help="name of the passage column", -) -parser.add_argument( - "--batch_size", - type=int, - default=1000, - help="Batch size (per device) for generating question answer pairs.", -) -parser.add_argument( - "--sample_size", - type=int, - default=1000, - help="Number of examples to process", -) -parser.add_argument( - "--output_dir", - type=str, - help="Output directory. 
Without '/' at the end", - required=True, -) -args = parser.parse_args() - -tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-squad-QuestionAnswer") -model = AutoModelForSeq2SeqLM.from_pretrained( - "potsawee/t5-large-generation-squad-QuestionAnswer", device_map="auto", load_in_8bit=True -) - - -def generate_question_answer_pairs(documents: dict) -> dict: + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate question answer pairs from the dataset of passages") + parser.add_argument( + "--dataset_path", + type=str, + default=None, + help="dataset path in the local dir. Can be huggingface dataset directory or a csv file.", + required=True, + ) + parser.add_argument( + "--title_column_name", + type=str, + default="Title", + help="This title is used to identify passages from the same text", + ) + parser.add_argument( + "--passage_column_name", + type=str, + default="Abstract", + help="name of the passage column", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="Batch size (per device) for generating question answer pairs.", + ) + parser.add_argument( + "--sample_size", + type=int, + default=1000, + help="Number of examples to process", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Output directory. Without '/' at the end", + required=True, + ) + parser.add_argument( + "--as_csv", + action="store_true", + help="Save the files as CSV. If False, will save them as a dataset directory via [`~Dataset.save_to_disk`]", + ) + args = parser.parse_args() + return args + + +def generate_question_answer_pairs( + documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str +) -> dict: """Generate question answer pairs""" - texts = documents[args.passage_column_name] + texts = documents[passage_column_name] - inputs = tokenizer(texts, return_tensors="pt", padding=True, max_length=150, truncation=True).to(device) + inputs = tokenizer(texts, return_tensors="pt", padding=True, max_length=150, truncation=True).to(DEVICE) outputs = model.generate(**inputs, max_new_tokens=50) question_answers = tokenizer.batch_decode(outputs, skip_special_tokens=False) question_answers = [ @@ -78,21 +89,18 @@ def generate_question_answer_pairs(documents: dict) -> dict: def filter_malformed_questions(record: dict) -> bool: - question = record["Question"] - answer = record["Answer"] + return record["Question"] != "-" and record["Answer"] != "-" - return question != "-" and answer != "-" - -def split_dataset(shuffled_dataset: datasets.Dataset, test_size: float = TEST_SIZE) -> datasets.DatasetDict: - unique_titles = set(shuffled_dataset[args.title_column_name]) +def split_dataset( + shuffled_dataset: datasets.Dataset, title_column_name: str, test_size: float = TEST_SIZE +) -> datasets.DatasetDict: + unique_titles = set(shuffled_dataset[title_column_name]) train_titles, test_titles = train_test_split(list(unique_titles), test_size=test_size, random_state=42) - train_dataset = shuffled_dataset.filter( - lambda example: example[args.title_column_name] in train_titles, num_proc=128 - ) - test_dataset = shuffled_dataset.filter(lambda example: example[args.title_column_name] in test_titles, num_proc=128) + train_dataset = shuffled_dataset.filter(lambda example: example[title_column_name] in train_titles, num_proc=128) + test_dataset = shuffled_dataset.filter(lambda example: example[title_column_name] in test_titles, num_proc=128) return datasets.DatasetDict( { @@ -102,31 +110,92 @@ def 
split_dataset(shuffled_dataset: datasets.Dataset, test_size: float = TEST_SI ) -dataset = datasets.load_dataset("csv", data_files={"data": args.dataset_path})["data"] - -# shuffle data -dataset.shuffle(seed=42) - -# select a subset -small_dataset = dataset.select(range(args.sample_size)) - -# train-test split -small_dataset_splits = split_dataset(small_dataset) - -print( - f"Train dataset size: {len(small_dataset_splits['train'])}, Test dataset size: {len(small_dataset_splits['test'])}" -) - -for split_name in small_dataset_splits: - processed_split = small_dataset_splits[split_name].map( - generate_question_answer_pairs, batched=True, batch_size=args.batch_size +def generate_qa_from_dataset( + dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int +) -> DatasetDict: + tokenizer = AutoTokenizer.from_pretrained(QA_MODEL) + model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=True) + # shuffle data + dataset = dataset.shuffle(seed=42) + # select a subset + small_dataset = dataset.select(range(sample_size)) + # train-test split + small_dataset_splits = split_dataset(small_dataset, title_column_name) + print( + f"Train dataset size: {len(small_dataset_splits['train'])}, " + f"Test dataset size: {len(small_dataset_splits['test'])}" + ) + qa_gen_map = partial( + generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name + ) + processed_data = small_dataset_splits.map(qa_gen_map, batched=True, batch_size=batch_size) + filtered_data = processed_data.filter(filter_malformed_questions) + print( + f"Malformed question answer pairs: " + f"(train: {len(processed_data['train']) - len(filtered_data['train'])} " + f"test: {len(processed_data['test']) - len(filtered_data['test'])})" + ) + return filtered_data + + +def _load_dataset_from_path(dataset_path: str) -> Dataset: + if dataset_path.endswith(".csv"): + dataset = Dataset.from_csv(dataset_path) + elif not os.path.splitext(dataset_path)[1]: + if os.path.isdir(dataset_path): + dataset = datasets.load_from_disk(dataset_path) + else: + dataset = datasets.load_dataset(dataset_path) + key = next(iter(dataset)) + if isinstance(dataset, DatasetDict): + warnings.warn(f"Found multiple keys in dataset.
Generating qa for split {key}", stacklevel=0) + dataset = dataset[key] + else: + raise ValueError( + "dataset-path must be one of csv, dataset directory " + "(ie saved using [`~Dataset.save_to_disk`] or a dataset on the huggingface hub" + ) + return dataset + + +def generate_qa_from_disk( + dataset_path: str, + passage_column_name: str, + title_column_name: str, + sample_size: int, + batch_size: int, + output_dir: str, + as_csv: bool, +) -> None: + dataset = _load_dataset_from_path(dataset_path) + qa_gen_data = generate_qa_from_dataset(dataset, passage_column_name, title_column_name, sample_size, batch_size) + output_path = Path(output_dir) + output_path.mkdir(exist_ok=True) + for split_name, split_ds in qa_gen_data.items(): + full_path = f"{output_path}/question_answer_pairs_{split_name}" + if as_csv: + full_path = f"{full_path}.csv" + split_ds.to_csv(full_path) + else: + split_ds.save_to_disk(full_path) + print(f"Saving split {split_name} to {full_path}") + + +def main() -> None: + args = parse_args() + generate_qa_from_disk( + args.dataset_path, + args.passage_column_name, + args.title_column_name, + args.sample_size, + args.batch_size, + args.output_dir, + args.as_csv, ) - filtered_split = processed_split.filter(filter_malformed_questions) - print(f"Malformed question answer pairs: {len(processed_split) - len(filtered_split)}") - filtered_split.save_to_disk(f"{args.output_dir}/question_answer_pairs_{split_name}") - filtered_split.to_csv(f"{args.output_dir}/question_answer_pairs_{split_name}.csv") +if __name__ == "__main__": + main() """ python question_answer_generation.py \ diff --git a/dalm/eval/utils.py b/dalm/eval/utils.py index 98b6764..4d2d1f6 100644 --- a/dalm/eval/utils.py +++ b/dalm/eval/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, cast import hnswlib import numpy as np @@ -103,13 +103,15 @@ def mixed_collate_fn(batch: List[Dict[str, torch.Tensor | str]]) -> Dict[str, to """ This is able to account for string values which the default PyTorch collate_fn would silently ignore """ - new_batch = {} + new_batch: Dict[str, torch.Tensor | List[str]] = {} keys = batch[0].keys() for key in keys: if isinstance(batch[0][key], str) or batch[0][key] is None: - new_batch[key] = [sample[key] for sample in batch] + # We cast because if the first element is a string, all elements in the batch are strings + new_batch[key] = cast(List[str], [sample[key] for sample in batch]) else: + # Otherwise all elements in the batch are tensors new_batch[key] = torch.stack([torch.tensor(sample[key]) for sample in batch]) return new_batch diff --git a/dalm/models/__init__.py b/dalm/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dalm/py.typed b/dalm/py.typed index 8b13789..e69de29 100644 --- a/dalm/py.typed +++ b/dalm/py.typed @@ -1 +0,0 @@ - diff --git a/dalm/training/rag_e2e/README.md b/dalm/training/rag_e2e/README.md index 343b64d..b43240a 100644 --- a/dalm/training/rag_e2e/README.md +++ b/dalm/training/rag_e2e/README.md @@ -21,12 +21,27 @@ Before you can execute the code, please make sure you have the following compone - Execute the following script with your desired configuration: -bash -Copy code +To train the model using a smaller retriever, and gpt2, you can run + + +```shell +dalm train-rag-e2e \ +"./dataset" \ +"BAAI/bge-small-en" \ +"gpt2" \ +--output-dir "./rag_e2e_checkpoints" \ +--with-tracking \ +--report-to tensorboard +``` +or, to execute directly, +```shell python train_rag_e2e.py --dataset_path 
"./dataset" \ --retriever_name_or_path "BAAI/bge-small-en" \ --generator_name_or_path "gpt2" \ --output_dir "./rag_e2e_checkpoints" \ - --with_tracking --report_to tensorboard + --with_tracking \ + --report_to tensorboard +``` + This script will start training the End2End Differentiable RAG model using the specified dataset and model configurations. diff --git a/dalm/training/rag_e2e/requirements.txt b/dalm/training/rag_e2e/requirements.txt deleted file mode 100644 index 2c3ab1a..0000000 --- a/dalm/training/rag_e2e/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -accelerate -transformers -torch -peft -datasets -evaluate -diffusers -bitsandbytes -torchvision \ No newline at end of file diff --git a/dalm/training/rag_e2e/train_rage2e.py b/dalm/training/rag_e2e/train_rage2e.py index b8fbcac..5485a9d 100644 --- a/dalm/training/rag_e2e/train_rage2e.py +++ b/dalm/training/rag_e2e/train_rage2e.py @@ -13,18 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys - -sys.path.append(os.getcwd()) # This is needed to import modules with absolute paths - -# ruff: noqa: E402 import argparse import logging import math +import os import random from argparse import Namespace -from typing import Dict, Union +from types import NoneType +from typing import Dict, Optional, Union import datasets import torch @@ -33,6 +29,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed +from datasets import Dataset from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ( @@ -84,7 +81,8 @@ def parse_args() -> Namespace: type=int, default=128, help=( - "The maximum total passage sequence length after tokenization. Sequences longer than this will be truncated," + "The maximum total passage sequence length after tokenization. " + "Sequences longer than this will be truncated," ), ) parser.add_argument( @@ -92,7 +90,8 @@ def parse_args() -> Namespace: type=int, default=256, help=( - "The maximum total generator input sequence length after tokenization. Sequences longer than this will be truncated," + "The maximum total generator input sequence length after tokenization. " + "Sequences longer than this will be truncated," ), ) parser.add_argument( @@ -194,8 +193,8 @@ def parse_args() -> Namespace: type=str, default="all", help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`, `"wandb"`, ' + '`"mlflow"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' "Only applicable when `--with_tracking` is passed." 
), ) @@ -214,15 +213,46 @@ def parse_args() -> Namespace: return args -def main() -> None: - args = parse_args() - +def train_e2e( + dataset_or_path: str | Dataset, + retriever_name_or_path: str, + generator_name_or_path: str, + dataset_passage_col_name: str = "Abstract", + dataset_query_col_name: str = "Question", + dataset_answer_col_name: str = "Answer", + query_max_len: int = 50, + passage_max_len: int = 128, + generator_max_len: int = 256, + per_device_train_batch_size: int = 32, + learning_rate: float = 1e-4, + logit_scale: int = 100, + weight_decay: float = 0.0, + num_train_epochs: int = 1, + max_train_steps: Optional[int] = None, + gradient_accumulation_steps: int = 1, + lr_scheduler_type: SchedulerType = SchedulerType.LINEAR, + num_warmup_steps: int = 100, + output_dir: Optional[str] = None, + seed: int = 42, + hub_model_id: Optional[str] = None, + hub_token: Optional[str] = None, + checkpointing_steps: Optional[int | str] = None, + resume_from_checkpoint: Optional[str] = None, + with_tracking: bool = True, + report_to: str = "all", + sanity_test: bool = True, + use_peft: bool = True, +) -> None: + """Train an in-domain model e2e with a retriever and generator. See `dalm train-rag-e2e --help` for more details""" + # Get the passed in vars before beginning training, in case we report training + args = dict(locals()) + # TensorBoard cannot log Enums, need the raw value + args["lr_scheduler_type"] = args["lr_scheduler_type"].value + args = {k: v for k, v in args.items() if v is None or isinstance(v, (float, int, str, NoneType))} # rag retriver and the generator - rag_model = AutoModelForRagE2E(args.retriever_name_or_path, args.generator_name_or_path) + rag_model = AutoModelForRagE2E(retriever_name_or_path, generator_name_or_path) - accelerator = ( - Accelerator(log_with=args.report_to, project_dir=args.output_dir) if args.with_tracking else Accelerator() - ) + accelerator = Accelerator(log_with=report_to, project_dir=output_dir) if with_tracking else Accelerator() # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -239,20 +269,22 @@ def main() -> None: transformers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
- if args.seed is not None: - set_seed(args.seed) + if seed is not None: + set_seed(seed) # Handle the repository creation if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) accelerator.wait_for_everyone() # dataset download and preprocessing dataset = ( - datasets.load_from_disk(args.dataset_path) - if os.path.isdir(args.dataset_path) - else datasets.load_dataset("csv", data_files=args.dataset_path)["train"] + dataset_or_path + if isinstance(dataset_or_path, Dataset) + else datasets.load_from_disk(dataset_or_path) + if os.path.isdir(dataset_or_path) + else datasets.load_dataset("csv", data_files=dataset_or_path)["train"] ) retriever_tokenizer = rag_model.retriever_tokenizer generator_tokenizer = rag_model.generator_tokenizer @@ -263,12 +295,12 @@ def main() -> None: example, retriever_tokenizer=rag_model.retriever_tokenizer, generator_tokenizer=rag_model.generator_tokenizer, - dataset_query_col_name=args.dataset_query_col_name, - dataset_passage_col_name=args.dataset_passage_col_name, - dataset_answer_col_name=args.dataset_answer_col_name, - query_max_len=args.query_max_len, - passage_max_len=args.passage_max_len, - generator_max_len=args.generator_max_len, + dataset_query_col_name=dataset_query_col_name, + dataset_passage_col_name=dataset_passage_col_name, + dataset_answer_col_name=dataset_answer_col_name, + query_max_len=query_max_len, + passage_max_len=passage_max_len, + generator_max_len=generator_max_len, ), batched=True, remove_columns=dataset.column_names, @@ -284,24 +316,24 @@ def main() -> None: processed_datasets, shuffle=True, collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size, + batch_size=per_device_train_batch_size, pin_memory=True, ) - optimizer = torch.optim.Adam(rag_model.parameters(), lr=args.learning_rate) + optimizer = torch.optim.Adam(rag_model.parameters(), lr=learning_rate) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + if max_train_steps is None: + max_train_steps = num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, + name=lr_scheduler_type, optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, + num_warmup_steps=num_warmup_steps, + num_training_steps=max_train_steps, ) (rag_model, optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( @@ -309,50 +341,48 @@ def main() -> None: ) # We need to recalculate our total training steps as the size of the training dataloader may have changed - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + max_train_steps = num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) # Figure out how many steps we should save the Accelerator states - checkpointing_steps = args.checkpointing_steps - if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = checkpointing_steps + if checkpointing_steps is not None and str(checkpointing_steps).isdigit(): checkpointing_steps = int(checkpointing_steps) # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if args.with_tracking: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + # The trackers initialize automatically on the main process. + if with_tracking: + experiment_config = args.copy() accelerator.init_trackers("peft_rag_e2e_learning", experiment_config) - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps - if args.use_peft: + if use_peft: # saving and loading checkpoints for resuming training accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) logger.info("***** Running training *****") logger.info(f" Num examples = {len(processed_datasets)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process) completed_steps = 0 starting_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) + if resume_from_checkpoint: + if resume_from_checkpoint is not None or resume_from_checkpoint != "": + accelerator.print(f"Resumed from checkpoint: {resume_from_checkpoint}") + accelerator.load_state(resume_from_checkpoint) + path = os.path.basename(resume_from_checkpoint) else: # Get the most recent checkpoint dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] @@ -367,18 +397,18 @@ def main() -> None: completed_steps = starting_epoch * num_update_steps_per_epoch else: # need to multiply `gradient_accumulation_steps` to reflect real steps - resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + resume_step = int(training_difference.replace("step_", "")) * gradient_accumulation_steps starting_epoch = resume_step // len(train_dataloader) resume_step -= starting_epoch * len(train_dataloader) - completed_steps = resume_step // args.gradient_accumulation_step + completed_steps = resume_step // gradient_accumulation_steps # update the progress_bar if load from checkpoint progress_bar.update(completed_steps) - for epoch in range(starting_epoch, args.num_train_epochs): + for epoch in range(starting_epoch, num_train_epochs): rag_model.train() total_loss: Union[float, torch.Tensor] = 0.0 - if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + if resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: # We skip the first `n` batches in the dataloader when resuming from a checkpoint active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) else: @@ -396,7 +426,7 @@ def main() -> None: batch["retriever_passage_attention_mask"], ) - logits = get_cosine_sim(query_embs, passage_embs, args.logit_scale) + logits = get_cosine_sim(query_embs, passage_embs, logit_scale) loss_query = get_nt_xent_loss(logits) loss_passage = get_nt_xent_loss(logits.t()) @@ -437,36 +467,36 @@ def main() -> None: completed_steps += 1 if (step + 1) % 100 == 0: - logger.info(f"Step: {step+1}, Loss: {total_loss/(step+1)}") - if args.with_tracking: + logger.info(f"Step: {step + 1}, Loss: {total_loss / (step + 1)}") + if with_tracking: accelerator.log({"train/loss": total_loss / (step + 1)}, step=completed_steps) if isinstance(checkpointing_steps, int): if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps }" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) + step_output_dir = f"step_{completed_steps}" + if output_dir is not None: + 
output_dir = os.path.join(output_dir, step_output_dir) accelerator.save_state(output_dir) - if completed_steps >= args.max_train_steps: + if completed_steps >= max_train_steps: break result: Dict[str, Union[int, float, torch.Tensor]] = {} # Use accelerator.print to print only on the main process. accelerator.print(f"epoch {epoch}:", result) - if args.with_tracking: + if with_tracking: step_loss = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss result["train/epoch_loss"] = step_loss / len(train_dataloader) accelerator.log(result, step=completed_steps) - if args.output_dir is not None: + if output_dir is not None: accelerator.wait_for_everyone() if accelerator.is_main_process: if isinstance(checkpointing_steps, str): - accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}")) + accelerator.save_state(os.path.join(output_dir, f"epoch_{epoch}")) - retriever_ckpt_path = args.output_dir + "/retriever" - generator_ckpt_path = args.output_dir + "/generator" + retriever_ckpt_path = output_dir + "/retriever" + generator_ckpt_path = output_dir + "/generator" # retriever saving unwrapped_rag_model = accelerator.unwrap_model(rag_model) @@ -486,6 +516,40 @@ def main() -> None: accelerator.end_training() +def main() -> None: + args = parse_args() + train_e2e( + dataset_or_path=args.dataset_path, + retriever_name_or_path=args.retriever_name_or_path, + generator_name_or_path=args.generator_name_or_path, + dataset_passage_col_name=args.dataset_passage_col_name, + dataset_query_col_name=args.dataset_query_col_name, + dataset_answer_col_name=args.dataset_answer_col_name, + query_max_len=args.query_max_len, + passage_max_len=args.passage_max_len, + generator_max_len=args.generator_max_len, + per_device_train_batch_size=args.per_device_train_batch_size, + learning_rate=args.learning_rate, + logit_scale=args.logit_scale, + weight_decay=args.weight_decay, + num_train_epochs=args.num_train_epochs, + max_train_steps=args.max_train_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + lr_scheduler_type=args.lr_scheduler_type, + num_warmup_steps=args.num_warmup_steps, + output_dir=args.output_dir, + seed=args.seed, + hub_model_id=args.hub_model_id, + hub_token=args.hub_token, + checkpointing_steps=args.checkpointing_steps, + resume_from_checkpoint=args.resume_from_checkpoint, + with_tracking=args.with_tracking, + report_to=args.report_to, + sanity_test=args.sanity_test, + use_peft=args.use_peft, + ) + + if __name__ == "__main__": main() diff --git a/dalm/training/retriever_only/README.md b/dalm/training/retriever_only/README.md index 7123a86..1c62b6b 100644 --- a/dalm/training/retriever_only/README.md +++ b/dalm/training/retriever_only/README.md @@ -1 +1,2 @@ # arcee-retriever + diff --git a/dalm/training/retriever_only/train_retriever_only.py b/dalm/training/retriever_only/train_retriever_only.py index cec471c..098851d 100644 --- a/dalm/training/retriever_only/train_retriever_only.py +++ b/dalm/training/retriever_only/train_retriever_only.py @@ -13,18 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import sys - -sys.path.append(os.getcwd()) # This is needed to import modules with absolute paths -# ruff: noqa: E402 - import argparse import logging import math +import os import random from argparse import Namespace -from typing import Dict, Union +from types import NoneType +from typing import Dict, Optional, Union import datasets import torch @@ -32,7 +28,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from black import Union +from datasets import Dataset, DatasetDict from torch.utils.data import DataLoader from tqdm import tqdm from transformers import SchedulerType, default_data_collator, get_scheduler @@ -67,7 +63,8 @@ def parse_args() -> Namespace: type=int, default=128, help=( - "The maximum total passage sequence length after tokenization. Sequences longer than this will be truncated," + "The maximum total passage sequence length after tokenization. " + "Sequences longer than this will be truncated," ), ) parser.add_argument( @@ -166,11 +163,40 @@ def parse_args() -> Namespace: return args -def main() -> None: - args = parse_args() - accelerator = ( - Accelerator(log_with=args.report_to, project_dir=args.output_dir) if args.with_tracking else Accelerator() - ) +def train_retriever( + model_name_or_path: str, + train_dataset_or_csv_path: str | Dataset, + test_dataset_or_csv_path: str | Dataset | None = None, + dataset_passage_col_name: str = "Abstract", + dataset_query_col_name: str = "Question", + query_max_len: int = 50, + passage_max_len: int = 128, + per_device_train_batch_size: int = 32, + learning_rate: float = 1e-4, + logit_scale: int = 100, + weight_decay: float = 0.0, + num_train_epochs: int = 1, + max_train_steps: Optional[int] = None, + gradient_accumulation_steps: int = 1, + lr_scheduler_type: SchedulerType = SchedulerType.LINEAR, + num_warmup_steps: int = 0, + output_dir: Optional[str] = None, + seed: int = 42, + hub_model_id: Optional[str] = None, + hub_token: Optional[str] = None, + checkpointing_steps: Optional[int | str] = None, + resume_from_checkpoint: Optional[str] = None, + with_tracking: bool = True, + report_to: str = "all", + sanity_test: bool = True, + use_peft: bool = True, +) -> None: + # Get the passed in vars before beginning training, in case we report training + args = dict(locals()) + # TensorBoard cannot log Enums, need the raw value + args["lr_scheduler_type"] = args["lr_scheduler_type"].value + args = {k: v for k, v in args.items() if v is None or isinstance(v, (float, int, str, NoneType))} + accelerator = Accelerator(log_with=report_to, project_dir=output_dir) if with_tracking else Accelerator() # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -187,32 +213,33 @@ def main() -> None: transformers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
- if args.seed is not None: - set_seed(args.seed) + if seed is not None: + set_seed(seed) # Handle the repository creation if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) accelerator.wait_for_everyone() - model = AutoModelForSentenceEmbedding(args.model_name_or_path, use_bnb=True, get_peft=args.use_peft) + model = AutoModelForSentenceEmbedding(model_name_or_path, use_bnb=True, get_peft=use_peft) tokenizer = model.tokenizer # dataset download and preprocessing - + if isinstance(train_dataset_or_csv_path, Dataset): + dataset = DatasetDict({"train": train_dataset_or_csv_path, "test": test_dataset_or_csv_path}) dataset = datasets.load_dataset( - "csv", data_files={"train": args.train_dataset_csv_path, "validation": args.test_dataset_csv_path} + "csv", data_files={"train": train_dataset_or_csv_path, "validation": test_dataset_or_csv_path} ) processed_datasets = dataset.map( lambda example: preprocess_dataset( example, tokenizer, - query_col_name=args.dataset_query_col_name, - passage_col_name=args.dataset_passage_col_name, - query_max_len=args.query_max_len, - passage_max_len=args.passage_max_len, + query_col_name=dataset_query_col_name, + passage_col_name=dataset_passage_col_name, + query_max_len=query_max_len, + passage_max_len=passage_max_len, ), batched=True, remove_columns=dataset["train"].column_names, @@ -224,7 +251,6 @@ def main() -> None: logger.info(f"Sample {index} of the training set: {processed_datasets['train'][index]}.") model.print_trainable_parameters() # type: ignore # No idea what mypy is complaining about. - accelerator.print(model) # get dataloaders @@ -232,23 +258,23 @@ def main() -> None: processed_datasets["train"], shuffle=True, collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size, + batch_size=per_device_train_batch_size, pin_memory=True, ) - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + if max_train_steps is None: + max_train_steps = num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, + name=lr_scheduler_type, optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, + num_warmup_steps=num_warmup_steps, + num_training_steps=max_train_steps, ) # Prepare everything with our `accelerator`. 
@@ -257,50 +283,47 @@ def main() -> None:
     )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        max_train_steps = num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    checkpointing_steps = args.checkpointing_steps
-    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+    if checkpointing_steps is not None and str(checkpointing_steps).isdigit():
         checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
-    if args.with_tracking:
-        experiment_config = vars(args)
-        # TensorBoard cannot log Enums, need the raw value
-        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
-        accelerator.init_trackers("peft_contrastive_learning", experiment_config)
+    # The trackers initialize automatically on the main process.
+    if with_tracking:
+        # `args` was already reduced to a dict of loggable values at the top of train_retriever
+        accelerator.init_trackers("peft_contrastive_learning", args)
 
-    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps
 
-    if args.use_peft:
+    if use_peft:
         # saving and loading checkpoints for resuming training
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(processed_datasets['train'])}")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
+    logger.info(f"  Num Epochs = {num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {max_train_steps}")
 
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
     starting_epoch = 0
 
     # Potentially load in the weights and states from a previous save
-    if args.resume_from_checkpoint:
-        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
-            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
-            accelerator.load_state(args.resume_from_checkpoint)
-            path = os.path.basename(args.resume_from_checkpoint)
+    if resume_from_checkpoint:
+        if resume_from_checkpoint is not None or resume_from_checkpoint != "":
+            accelerator.print(f"Resumed from checkpoint: {resume_from_checkpoint}")
+            accelerator.load_state(resume_from_checkpoint)
+            path = os.path.basename(resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
@@ -315,18 +341,18 @@ def main() -> None:
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
             # need to multiply `gradient_accumulation_steps` to reflect real steps
-            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
+            resume_step = int(training_difference.replace("step_", "")) * gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step // args.gradient_accumulation_step
+            completed_steps = resume_step // gradient_accumulation_steps
 
     # update the progress_bar if load from checkpoint
     progress_bar.update(completed_steps)
 
-    for epoch in range(starting_epoch, args.num_train_epochs):
+    for epoch in range(starting_epoch, num_train_epochs):
         model.train()
         total_loss: Union[float, torch.Tensor] = 0.0
-        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+        if resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
             # We skip the first `n` batches in the dataloader when resuming from a checkpoint
             active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
         else:
@@ -335,7 +361,7 @@ def main() -> None:
             with accelerator.accumulate(model):
                 query_embs = model(**{k.replace("query_", ""): v for k, v in batch.items() if "query" in k})
                 passage_embs = model(**{k.replace("passage_", ""): v for k, v in batch.items() if "passage" in k})
-                logits = get_cosine_sim(query_embs, passage_embs, args.logit_scale)
+                logits = get_cosine_sim(query_embs, passage_embs, logit_scale)
 
                 loss_query = get_nt_xent_loss(logits)
                 loss_passage = get_nt_xent_loss(logits.t())
@@ -353,41 +379,73 @@ def main() -> None:
                 completed_steps += 1
 
                 if (step + 1) % 100 == 0:
-                    logger.info(f"Step: {step+1}, Loss: {total_loss/(step+1)}")
-                    if args.with_tracking:
+                    logger.info(f"Step: {step + 1}, Loss: {total_loss / (step + 1)}")
+                    if with_tracking:
                         accelerator.log({"train/loss": total_loss / (step + 1)}, step=completed_steps)
 
                 if isinstance(checkpointing_steps, int):
                     if completed_steps % checkpointing_steps == 0:
-                        output_dir = f"step_{completed_steps }"
-                        if args.output_dir is not None:
-                            output_dir = os.path.join(args.output_dir, output_dir)
-                        accelerator.save_state(output_dir)
+                        step_output_dir = f"step_{completed_steps}"
+                        if output_dir is not None:
+                            step_output_dir = os.path.join(output_dir, step_output_dir)
+                        accelerator.save_state(step_output_dir)
 
-                if completed_steps >= args.max_train_steps:
+                if completed_steps >= max_train_steps:
                     break
 
         result: Dict[str, Union[int, float, torch.Tensor]] = {}
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", result)
-        if args.with_tracking:
+        if with_tracking:
             step_loss = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss
             result["train/epoch_loss"] = step_loss / len(train_dataloader)
             accelerator.log(result, step=completed_steps)
 
-    if args.output_dir is not None:
+    if output_dir is not None:
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
             if isinstance(checkpointing_steps, str):
-                accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
+                accelerator.save_state(os.path.join(output_dir, f"epoch_{epoch}"))
             accelerator.unwrap_model(model).save_pretrained(
-                args.output_dir, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model))
+                output_dir, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model))
             )
-            tokenizer.save_pretrained(args.output_dir)
+            tokenizer.save_pretrained(output_dir)
         accelerator.wait_for_everyone()
     accelerator.end_training()
 
 
+def main() -> None:
+    args = parse_args()
+    train_retriever(
+        train_dataset_or_csv_path=args.train_dataset_csv_path,
+        test_dataset_or_csv_path=args.test_dataset_csv_path,
+        model_name_or_path=args.model_name_or_path,
+        dataset_passage_col_name=args.dataset_passage_col_name,
+        dataset_query_col_name=args.dataset_query_col_name,
+        query_max_len=args.query_max_len,
+        passage_max_len=args.passage_max_len,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        learning_rate=args.learning_rate,
+        logit_scale=args.logit_scale,
+        weight_decay=args.weight_decay,
+        num_train_epochs=args.num_train_epochs,
+        max_train_steps=args.max_train_steps,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        lr_scheduler_type=args.lr_scheduler_type,
+        num_warmup_steps=args.num_warmup_steps,
+        output_dir=args.output_dir,
+        seed=args.seed,
+        hub_model_id=args.hub_model_id,
+        hub_token=args.hub_token,
+        checkpointing_steps=args.checkpointing_steps,
+        resume_from_checkpoint=args.resume_from_checkpoint,
+        with_tracking=args.with_tracking,
+        report_to=args.report_to,
+        sanity_test=args.sanity_test,
+        use_peft=args.use_peft,
+    )
+
+
 if __name__ == "__main__":
     main()
diff --git a/dalm/training/utils/train_utils.py b/dalm/training/utils/train_utils.py
index 5f13d72..4023fb3 100644
--- a/dalm/training/utils/train_utils.py
+++ b/dalm/training/utils/train_utils.py
@@ -2,17 +2,17 @@
 
 import torch
 import torch.nn.functional as F
-from transformers import AutoModel
+from peft import PeftModel
 
 
-def save_model_hook(models: List[AutoModel], weights: List, output_dir: str) -> None:
+def save_model_hook(models: List[PeftModel], weights: List, output_dir: str) -> None:
     for i, model in enumerate(models):
         model.save_pretrained(output_dir, state_dict=weights[i])
         # make sure to pop weight so that corresponding model is not saved again
         weights.pop()
 
 
-def load_model_hook(models: List[AutoModel], input_dir: str) -> None:
+def load_model_hook(models: List[PeftModel], input_dir: str) -> None:
     while len(models) > 0:
         model = models.pop()  # pop models so that they are not loaded again
@@ -20,9 +20,7 @@ def load_model_hook(models: List[AutoModel], input_dir: str) -> None:
         model.load_adapter(input_dir, model.active_adapter, is_trainable=True)
 
 
-def get_cosine_sim(
-    query_embs: torch.FloatTensor, passage_embs: torch.FloatTensor, logit_scale: torch.FloatTensor
-) -> torch.Tensor:
+def get_cosine_sim(query_embs: torch.FloatTensor, passage_embs: torch.FloatTensor, logit_scale: int) -> torch.Tensor:
     return torch.matmul(query_embs, passage_embs.t()) * logit_scale
diff --git a/pyproject.toml b/pyproject.toml
index b45d954..f6d61ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,8 +29,12 @@ dependencies = [
     "torchvision",
     "pydantic",
     "typer",
+    "scipy"
 ]
 
+[project.scripts]
+dalm = "dalm.cli:cli"
+
 [tool.hatch.build.targets.wheel.shared-data]
 "prefix" = "prefix"
 
@@ -95,6 +99,7 @@ addopts = [
     "--cov-report=term-missing",
     "--cov-report=xml",
     "--cov-report=html",
+    "--durations=10",
 ]
 
diff --git a/tasks.py b/tasks.py
index 2ebbf85..5c747b5 100644
--- a/tasks.py
+++ b/tasks.py
@@ -134,6 +134,11 @@ def build(ctx: Context) -> None:
     """
     Build the package.
     """
+    ctx.run(
+        "pip install --upgrade build",
+        pty=True,
+        echo=True,
+    )
     ctx.run(
         "python -m build",
         pty=True,
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 0000000..6bd6839
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,8 @@
+import os
+
+from dalm import __version__
+
+
+def test_cli_version() -> None:
+    version_msg = os.popen("dalm version").readlines()[-1].strip()
+    assert version_msg == f"🐾You are running DALM version: {__version__}"
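Below is a minimal usage sketch (not part of the diff) of the new `train_retriever` entry point, driven with in-memory `datasets.Dataset` objects instead of CSV paths. The import path, the toy rows, and the checkpoint choice are assumptions for illustration only; the keyword arguments mirror the signature introduced above.

```python
# Hypothetical smoke test for the refactored entry point -- module path assumed, not confirmed by this diff.
from datasets import Dataset

from dalm.training.retriever_only.train_retriever_only import train_retriever  # path is an assumption

toy = {
    "Abstract": [
        "Retrieval-augmented generation grounds answers in retrieved passages.",
        "Contrastive training pulls matching query and passage embeddings together.",
    ],
    "Question": [
        "What grounds answers in retrieval-augmented generation?",
        "What does contrastive training pull together?",
    ],
}

train_retriever(
    model_name_or_path="BAAI/bge-large-en",            # the retriever used in the README experiments
    train_dataset_or_csv_path=Dataset.from_dict(toy),  # Dataset object instead of a CSV path
    test_dataset_or_csv_path=Dataset.from_dict(toy),
    dataset_passage_col_name="Abstract",
    dataset_query_col_name="Question",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    output_dir="./retriever_only_output",
    with_tracking=False,  # skip tracker initialization for a quick local run
    use_peft=True,
)
```

The CLI path is unchanged: `main()` still calls `parse_args()` and forwards everything to the same function, so library users simply gain a programmatic hook.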
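For reviewers new to the contrastive objective in the training loop, the sketch below shows how the scaled in-batch cosine-similarity logits feed a symmetric cross-entropy loss. `get_cosine_sim` reproduces the computation from `train_utils.py` (with the new `logit_scale: int` signature); `nt_xent` is a simplified stand-in for the repo's `get_nt_xent_loss`, written only to illustrate the diagonal-target idea, not to match its exact implementation.

```python
# Illustrative only: a simplified stand-in for the repo's NT-Xent loss.
import torch
import torch.nn.functional as F


def get_cosine_sim(query_embs: torch.Tensor, passage_embs: torch.Tensor, logit_scale: int) -> torch.Tensor:
    # Same computation as dalm.training.utils.train_utils.get_cosine_sim;
    # assumes the embeddings are already L2-normalized.
    return torch.matmul(query_embs, passage_embs.t()) * logit_scale


def nt_xent(logits: torch.Tensor) -> torch.Tensor:
    # Each query's positive passage sits on the diagonal of the in-batch similarity matrix.
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)


batch, dim = 4, 8
query_embs = F.normalize(torch.randn(batch, dim), dim=-1)
passage_embs = F.normalize(torch.randn(batch, dim), dim=-1)

logits = get_cosine_sim(query_embs, passage_embs, logit_scale=100)
loss = (nt_xent(logits) + nt_xent(logits.t())) / 2  # symmetric query->passage and passage->query terms
print(f"toy contrastive loss: {loss.item():.4f}")
```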