diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml index e94092c91..828ee59e3 100644 --- a/.github/workflows/workflow.yaml +++ b/.github/workflows/workflow.yaml @@ -152,4 +152,4 @@ jobs: uses: actions/checkout@v3 - name: Start docker compose and exit when tests run through - run: bash run_integrationtests.sh \ No newline at end of file + run: bash run_integrationtests.sh diff --git a/benchmark/README.md b/benchmark/README.md index 725fe4e3c..108d56736 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -11,3 +11,6 @@ In the `mnist` directory, you find files to running experiments on a dynamic ver ### Criteo 1TB Dataset We provide scripts and guidelines for using the Criteo 1TB benchmark data set. The README in the subfolder contains information on how data is downloaded and preprocessed. + +### Wildtime benchmarks +In the `wildtime_benchmarks` directory, you find files to running experiments on datasets belonging to the WildTime suite. \ No newline at end of file diff --git a/benchmark/wildtime_benchmarks/README.md b/benchmark/wildtime_benchmarks/README.md new file mode 100644 index 000000000..8028229db --- /dev/null +++ b/benchmark/wildtime_benchmarks/README.md @@ -0,0 +1,257 @@ +# Wild-time datasets + +In this directory, you can find the files necessary to run experiments using benchmarks from Wild-time. +There are 4 available datasets: **arxiv**, **huffpost**, **FMoW** and **yearbook**. +You can find more details in the [wild time repo](https://github.com/huaxiuyao/Wild-Time) + +## Data Generation +To run the downloading script you need to install the `gdown` library and, just for FMoW, also the `wilds` library. + +The downloading scripts are adapted from `wild-time-data`. +There is a `data_generation_[benchmark].py` script for each available dataset. +Use the `-h` flag to find out more. + + +## Datasets description + +### Yearbook +The goal is to predict the sex given a yearbook picture. +The dataset contains 37189 samples collected from 1930 to 2013. +Since timestamps in Modyn are based on Unix Timestamps (so 0 is 1/1/1970) we have to remap the years to days. +Precisely, the timestamp for pictures from 1930 is 1/1/1970, then 2/1/1970 for the ones taken in 1931 and so forth. +Samples are saved using BinaryFileWrapper by grouping all samples of the same year in one file. + +### FMoW +The goal is to predict land use for example, _park_, _port_, _police station_ and _swimming pool_, given a satellite image. +Due to human activity, satellite imagery changes over time, requiring models that are robust to temporal distribution shifts. +The dataset contains more than 100.000 samples collected from 2002 to 2017. +Every picture is stored separately (in png format and loaded using SingleSampleFileWrapper) and the os timestamp is set accordingly. + +### HuffPost +The goal is to predict the tag of news given headlines. +The dataset contains more than 60k samples collected from 2012 to 2018. +Titles belonging to the same year are grouped into the same CSV file and stored together. +Each year is mapped to a year starting from 1/1/1970. + +### Arxiv +The goal is to predict the paper category (55 classes) given the paper title. +The dataset contains more than 2 million samples collected from 2002 to 2017. +Titles belonging to the same year are grouped into the same CSV file and stored together. +Each year is mapped to a year starting from 1/1/1970. + +## DOWNLOAD UTILS license +Some code relies on the [Wild-Time-Data repository](https://github.com/wistuba/Wild-Time-Data). +The copyright for that code lies at the author of the wild time data, wistuba. +Please find a full copy of the license at the end of this README. + + +## DATASET licenses +We list the licenses for each Wild-Time dataset below: + +- Yearbook: MIT License +- FMoW: [The Functional Map of the World Challenge Public License](https://raw.githubusercontent.com/fMoW/dataset/master/LICENSE) +- Huffpost: CC0: Public Domain +- arXiv: CC0: Public Domain + +## Wild-time data license + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/benchmark/wildtime_benchmarks/benchmark_utils.py b/benchmark/wildtime_benchmarks/benchmark_utils.py new file mode 100644 index 000000000..10cfebd63 --- /dev/null +++ b/benchmark/wildtime_benchmarks/benchmark_utils.py @@ -0,0 +1,52 @@ +import argparse +import logging +import pathlib +from datetime import datetime + +import gdown + +DAY_LENGTH_SECONDS = 24 * 60 * 60 + + +def download_if_not_exists(drive_id: str, destination_dir: str, destination_file_name: str) -> None: + """ + Function to download data from Google Drive. Used for Wild-time based benchmarks. + This function is adapted from wild-time-data's maybe_download + """ + destination_dir = pathlib.Path(destination_dir) + destination = destination_dir / destination_file_name + if destination.exists(): + return + destination_dir.mkdir(parents=True, exist_ok=True) + gdown.download( + url=f"https://drive.google.com/u/0/uc?id={drive_id}&export=download&confirm=pbef", + output=str(destination), + quiet=False, + ) + + +def setup_argparser_wildtime(dataset: str) -> argparse.ArgumentParser: + parser_ = argparse.ArgumentParser(description=f"{dataset} Benchmark Storage Script") + parser_.add_argument( + "--dir", type=pathlib.Path, action="store", help="Path to data directory" + ) + + return parser_ + + +def setup_logger(): + logging.basicConfig( + level=logging.NOTSET, + format="[%(asctime)s] [%(filename)15s:%(lineno)4d] %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", + ) + return logging.getLogger(__name__) + + +def create_fake_timestamp(year: int, base_year: int) -> int: + timestamp = ((year - base_year) * DAY_LENGTH_SECONDS) + 1 + return timestamp + + +def create_timestamp(year: int, month: int = 1, day: int = 1) -> int: + return int(datetime(year=year, month=month, day=day).timestamp()) diff --git a/benchmark/wildtime_benchmarks/data_generation_arxiv.py b/benchmark/wildtime_benchmarks/data_generation_arxiv.py new file mode 100644 index 000000000..7191d035d --- /dev/null +++ b/benchmark/wildtime_benchmarks/data_generation_arxiv.py @@ -0,0 +1,69 @@ +import os +import pickle + +import torch +from benchmark_utils import create_timestamp, download_if_not_exists, setup_argparser_wildtime, setup_logger +from torch.utils.data import Dataset +from tqdm import tqdm + +logger = setup_logger() + + +def main(): + parser = setup_argparser_wildtime("Arxiv") + args = parser.parse_args() + + logger.info(f"Downloading data to {args.dir}") + ArXivDownloader(args.dir).store_data() + + +class ArXivDownloader(Dataset): + time_steps = [i for i in range(2007, 2023)] + input_dim = 55 + num_classes = 172 + drive_id = "1H5xzHHgXl8GOMonkb6ojye-Y2yIp436V" + file_name = "arxiv.pkl" + + def __getitem__(self, idx): + return self._dataset["title"][idx], torch.LongTensor([self._dataset["category"][idx]])[0] + + def __len__(self): + return len(self._dataset["category"]) + + def __init__(self, data_dir): + super().__init__() + + download_if_not_exists( + drive_id=self.drive_id, + destination_dir=data_dir, + destination_file_name=self.file_name, + ) + datasets = pickle.load(open(os.path.join(data_dir, self.file_name), "rb")) + assert self.time_steps == list(sorted(datasets.keys())) + self._dataset = datasets + self.path = data_dir + + def store_data(self): + for year in tqdm(self._dataset): + # for simplicity, instead of using years we map each day to a year from 1970 + year_timestamp = create_timestamp(year=1970, month=1, day=year-2006) + year_rows = [] + for i in range(len(self._dataset[year][0]["title"])): + text = self._dataset[year][0]["title"][i].replace("\n", " ") + label = self._dataset[year][0]["category"][i] + csv_row = f"{text}\t{label}" + year_rows.append(csv_row) + + # store the year file + text_file = os.path.join(self.path, f"{year}.csv") + with open(text_file, "w", encoding="utf-8") as f: + f.write("\n".join(year_rows)) + + # set timestamp + os.utime(text_file, (year_timestamp, year_timestamp)) + + os.remove(os.path.join(self.path, "arxiv.pkl")) + + +if __name__ == "__main__": + main() diff --git a/benchmark/wildtime_benchmarks/data_generation_fmow.py b/benchmark/wildtime_benchmarks/data_generation_fmow.py new file mode 100644 index 000000000..cfa98c40a --- /dev/null +++ b/benchmark/wildtime_benchmarks/data_generation_fmow.py @@ -0,0 +1,101 @@ +import csv +import os +import pickle +import shutil +from datetime import datetime + +from benchmark_utils import download_if_not_exists, setup_argparser_wildtime, setup_logger +from torch.utils.data import Dataset +from tqdm import tqdm +from wilds import get_dataset + +logger = setup_logger() + + +def main() -> None: + parser = setup_argparser_wildtime("FMoW") + args = parser.parse_args() + + logger.info(f"Downloading data to {args.dir}") + + downloader = FMOWDownloader(args.dir) + downloader.store_data() + downloader.clean_folder() + + +class FMOWDownloader(Dataset): + time_steps = list(range(16)) + input_dim = (3, 224, 224) + num_classes = 62 + drive_id = "1s_xtf2M5EC7vIFhNv_OulxZkNvrVwIm3" + file_name = "fmow.pkl" + + def __init__(self, data_dir: str) -> None: + download_if_not_exists( + drive_id=self.drive_id, + destination_dir=data_dir, + destination_file_name=self.file_name, + ) + datasets = pickle.load(open(os.path.join(data_dir, self.file_name), "rb")) + self._dataset = datasets + try: + self._root = get_dataset(dataset="fmow", root_dir=data_dir, download=True).root + except ValueError: + pass + self.metadata = self.parse_metadata(data_dir) + self.data_dir = data_dir + + def clean_folder(self) -> None: + folder_path = os.path.join(self.data_dir, "fmow_v1.1") + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + + def move_file_and_rename(self, index: int) -> None: + source_dir = os.path.join(self.data_dir, "fmow_v1.1", "images") + if os.path.exists(source_dir) and os.path.isdir(source_dir): + src_file = os.path.join(source_dir, f"rgb_img_{index}.png") + dest_file = os.path.join(self.data_dir, f"rgb_img_{index}.png") + shutil.move(src_file, dest_file) + new_name = os.path.join(self.data_dir, f"{index}.png") + os.rename(dest_file, new_name) + + def store_data(self) -> None: + + for year in tqdm(self._dataset): + split = 0 # just use training split for now + for i in range(len(self._dataset[year][split]["image_idxs"])): + index = self._dataset[year][split]["image_idxs"][i] + label = self._dataset[year][split]["labels"][i] + raw_timestamp = self.metadata[index]["timestamp"] + + if len(raw_timestamp) == 24: + timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + else: + timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%SZ') + + # save label + label_file = os.path.join(self.data_dir, f"{index}.label") + with open(label_file, "w", encoding="utf-8") as f: + f.write(str(int(label))) + os.utime(label_file, (timestamp.timestamp(), timestamp.timestamp())) + + # set image timestamp + self.move_file_and_rename(index) + image_file = os.path.join(self.data_dir, f"{index}.png") + os.utime(image_file, (timestamp.timestamp(), timestamp.timestamp())) + + @staticmethod + def parse_metadata(data_dir: str) -> list: + filename = os.path.join(data_dir, "fmow_v1.1", "rgb_metadata.csv") + metadata = [] + + with open(filename, 'r') as file: + csv_reader = csv.reader(file) + next(csv_reader) + for row in csv_reader: + picture_info = {"split": row[0], "timestamp": row[11]} + metadata.append(picture_info) + return metadata + +if __name__ == "__main__": + main() diff --git a/benchmark/wildtime_benchmarks/data_generation_huffpost.py b/benchmark/wildtime_benchmarks/data_generation_huffpost.py new file mode 100644 index 000000000..497829e89 --- /dev/null +++ b/benchmark/wildtime_benchmarks/data_generation_huffpost.py @@ -0,0 +1,69 @@ +import os +import pickle +from datetime import datetime + +import torch +from benchmark_utils import create_timestamp, download_if_not_exists, setup_argparser_wildtime, setup_logger +from torch.utils.data import Dataset +from tqdm import tqdm + +logger = setup_logger() + + +def main(): + parser = setup_argparser_wildtime("Huffpost") + args = parser.parse_args() + + logger.info(f"Downloading data to {args.dir}") + HuffpostDownloader(args.dir).store_data() + + +class HuffpostDownloader(Dataset): + time_steps = [i for i in range(2012, 2019)] + input_dim = 44 + num_classes = 11 + drive_id = "1jKqbfPx69EPK_fjgU9RLuExToUg7rwIY" + file_name = "huffpost.pkl" + + def __getitem__(self, idx): + return self._dataset["title"][idx], torch.LongTensor([self._dataset["category"][idx]])[0] + + def __len__(self): + return len(self._dataset["category"]) + + def __init__(self, data_dir: str): + super().__init__() + + download_if_not_exists( + drive_id=self.drive_id, + destination_dir=data_dir, + destination_file_name=self.file_name, + ) + datasets = pickle.load(open(os.path.join(data_dir, self.file_name), "rb")) + assert self.time_steps == list(sorted(datasets.keys())) + self._dataset = datasets + self.path = data_dir + + def store_data(self) -> None: + for year in tqdm(self._dataset): + year_timestamp = create_timestamp(year=1970, month=1, day=year-2011) + year_rows = [] + for i in range(len(self._dataset[year][0]["headline"])): + text = self._dataset[year][0]["headline"][i] + label = self._dataset[year][0]["category"][i] + csv_row = f"{text}\t{label}" + year_rows.append(csv_row) + + # store the sentences + text_file = os.path.join(self.path, f"{year}.csv") + with open(text_file, "w", encoding="utf-8") as f: + f.write("\n".join(year_rows)) + + # set timestamp + os.utime(text_file, (year_timestamp, year_timestamp)) + + os.remove(os.path.join(self.path, "huffpost.pkl")) + + +if __name__ == "__main__": + main() diff --git a/benchmark/wildtime_benchmarks/data_generation_yearbook.py b/benchmark/wildtime_benchmarks/data_generation_yearbook.py new file mode 100644 index 000000000..01fe9e6f1 --- /dev/null +++ b/benchmark/wildtime_benchmarks/data_generation_yearbook.py @@ -0,0 +1,89 @@ +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from benchmark_utils import create_fake_timestamp, download_if_not_exists, setup_argparser_wildtime, setup_logger +from torch.utils.data import Dataset + +logger = setup_logger() + + +def main(): + parser = setup_argparser_wildtime("Yearbook") + args = parser.parse_args() + + logger.info(f"Downloading data to {args.dir}") + + downloader = YearbookDownloader(args.dir) + downloader.store_data() + + +class YearbookDownloader(Dataset): + time_steps = [i for i in range(1930, 2014)] + input_dim = (1, 32, 32) + num_classes = 2 + drive_id = "1mPpxoX2y2oijOvW1ymiHEYd7oMu2vVRb" + file_name = "yearbook.pkl" + + def __init__(self, data_dir: str): + super().__init__() + download_if_not_exists( + drive_id=self.drive_id, + destination_dir=data_dir, + destination_file_name=self.file_name, + ) + datasets = pickle.load(open(os.path.join(data_dir, self.file_name), "rb")) + self._dataset = datasets + self.data_dir = data_dir + + def _get_year_data(self, year: int) -> list[Tuple]: + images = torch.FloatTensor( + np.array( + [ # transpose to transform from HWC to CHW (H=height, W=width, C=channels). + # Pytorch requires CHW format + img.transpose(2, 0, 1)[0].reshape(*self.input_dim) + # _dataset has 3 dimensions [years][train=0,valid=1]["images"/"labels"] + for img in self._dataset[year][0]["images"] + ] + ) + ) + labels = torch.LongTensor(self._dataset[year][0]["labels"]) + return [(images[i], labels[i]) for i in range(len(images))] + + def __len__(self) -> int: + return len(self._dataset["labels"]) + + def store_data(self) -> None: + # create directories + if not os.path.exists(self.data_dir): + os.mkdir(self.data_dir) + + for year in self.time_steps: + print(f"Saving data for year {year}") + ds = self._get_year_data(year) + self.create_binary_file(ds, + os.path.join(self.data_dir, f"{year}.bin"), + create_fake_timestamp(year, base_year=1930)) + + os.remove(os.path.join(self.data_dir, "yearbook.pkl")) + + @staticmethod + def create_binary_file(data, output_file_name: str, timestamp: int) -> None: + with open(output_file_name, "wb") as f: + for tensor1, tensor2 in data: + features_bytes = tensor1.numpy().tobytes() + label_integer = tensor2.item() + + features_size = len(features_bytes) + assert features_size == 4096 + + f.write(int.to_bytes(label_integer, length=4, byteorder="big")) + f.write(features_bytes) + + os.utime(output_file_name, (timestamp, timestamp)) + + +if __name__ == "__main__": + main() diff --git a/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml new file mode 100644 index 000000000..0415e6a61 --- /dev/null +++ b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml @@ -0,0 +1,48 @@ +pipeline: + name: ArXiv dataset Test Pipeline + description: Example pipeline + version: 1.0.0 +model: + id: ArticleNet + config: + num_classes: 172 +training: + gpus: 1 + device: "cuda:0" + dataloader_workers: 2 + use_previous_model: True + initial_model: random + initial_pass: + activated: False + batch_size: 96 + optimizers: + - name: "default" + algorithm: "SGD" + source: "PyTorch" + param_groups: + - module: "model" + config: + lr: 0.00002 + momentum: 0.9 + weight_decay: 0.01 + optimization_criterion: + name: "CrossEntropyLoss" + checkpointing: + activated: False + selection_strategy: + name: NewDataStrategy + maximum_keys_in_memory: 10000 + config: + limit: -1 + reset_after_trigger: True +data: + dataset_id: arxiv + bytes_parser_function: | + def bytes_parser_function(data: bytes) -> str: + return data.decode("utf-8") + tokenizer: DistilBertTokenizerTransform + +trigger: + id: TimeTrigger + trigger_config: + trigger_every: "1d" \ No newline at end of file diff --git a/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml new file mode 100644 index 000000000..80bd1aa28 --- /dev/null +++ b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml @@ -0,0 +1,52 @@ +pipeline: + name: FunctionalMapoftheWorld (fmow) Test Pipeline + description: Example pipeline + version: 1.0.0 +model: + id: FmowNet + config: + num_classes: 62 +training: + gpus: 1 + device: "cuda:0" + dataloader_workers: 2 + use_previous_model: True + initial_model: random + initial_pass: + activated: False + batch_size: 64 + optimizers: + - name: "default" + algorithm: "SGD" + source: "PyTorch" + param_groups: + - module: "model" + config: + lr: 0.0001 + momentum: 0.39 + optimization_criterion: + name: "CrossEntropyLoss" + checkpointing: + activated: False + selection_strategy: + name: NewDataStrategy + maximum_keys_in_memory: 1000 + config: + limit: -1 + reset_after_trigger: True +data: + dataset_id: fmow + transformations: [ + "transforms.ToTensor()", + "transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])", + ] + bytes_parser_function: | + from PIL import Image + import io + def bytes_parser_function(data: bytes) -> Image: + return Image.open(io.BytesIO(data)).convert("RGB") +#change and configure it if you want day/week/month granularity +trigger: + id: DataAmountTrigger + trigger_config: + data_points_for_trigger: 2000 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml new file mode 100644 index 000000000..667522f18 --- /dev/null +++ b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml @@ -0,0 +1,48 @@ +pipeline: + name: Huffpost dataset Test Pipeline + description: Example pipeline + version: 1.0.0 +model: + id: ArticleNet + config: + num_classes: 55 +training: + gpus: 1 + device: "cuda:0" + dataloader_workers: 2 + use_previous_model: True + initial_model: random + initial_pass: + activated: False + batch_size: 64 + optimizers: + - name: "default" + algorithm: "SGD" + source: "PyTorch" + param_groups: + - module: "model" + config: + lr: 0.00002 + momentum: 0.9 + weight_decay: 0.01 + optimization_criterion: + name: "CrossEntropyLoss" + checkpointing: + activated: False + selection_strategy: + name: NewDataStrategy + maximum_keys_in_memory: 1000 + config: + limit: -1 + reset_after_trigger: True +data: + dataset_id: huffpost + bytes_parser_function: | + def bytes_parser_function(data: bytes) -> str: + return data.decode("utf-8") + tokenizer: DistilBertTokenizerTransform + +trigger: + id: TimeTrigger + trigger_config: + trigger_every: "1d" \ No newline at end of file diff --git a/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml new file mode 100644 index 000000000..d35dd09b5 --- /dev/null +++ b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml @@ -0,0 +1,50 @@ +pipeline: + name: Yearbook Test Pipeline + description: Example pipeline + version: 1.0.0 +model: + id: YearbookNet + config: + num_input_channels: 1 + num_classes: 2 +training: + gpus: 1 + device: "cuda:0" + dataloader_workers: 2 + use_previous_model: True + initial_model: random + initial_pass: + activated: False + batch_size: 64 + optimizers: + - name: "default" + algorithm: "SGD" + source: "PyTorch" + param_groups: + - module: "model" + config: + lr: 0.001 + momentum: 0.9 + optimization_criterion: + name: "CrossEntropyLoss" + checkpointing: + activated: False + selection_strategy: + name: NewDataStrategy + maximum_keys_in_memory: 1000 + config: + limit: -1 + reset_after_trigger: True +data: + dataset_id: yearbook + transformations: [] + bytes_parser_function: | + import torch + import numpy as np + def bytes_parser_function(data: bytes) -> torch.Tensor: + return torch.from_numpy(np.frombuffer(data, dtype=np.float32)).reshape(1, 32, 32) + +trigger: + id: TimeTrigger + trigger_config: + trigger_every: "1d" \ No newline at end of file diff --git a/environment.yml b/environment.yml index 586aedbd0..e3e31a4c9 100644 --- a/environment.yml +++ b/environment.yml @@ -11,6 +11,7 @@ channels: - anaconda - nvidia - pytorch + - huggingface dependencies: - python>=3.9 @@ -29,8 +30,9 @@ dependencies: - types-protobuf - types-psycopg2 - types-PyYAML + - transformers - pytorch::pytorch - pytorch::torchvision - pytorch::cpuonly # comment out if commenting in lines below for CUDA # - nvidia::cudatoolkit=11.7 -# - pytorch::pytorch-cuda=11.7 \ No newline at end of file +# - pytorch::pytorch-cuda=11.7 diff --git a/modyn/config/examples/modyn_config.yaml b/modyn/config/examples/modyn_config.yaml index e7c139ffa..406a529da 100644 --- a/modyn/config/examples/modyn_config.yaml +++ b/modyn/config/examples/modyn_config.yaml @@ -45,6 +45,90 @@ storage: ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 2000000, + }, + { + name: "yearbook", + description: "Yearbook Dataset from Wild-Time", + version: "0.0.1", + base_path: "/datasets/yearbook", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "BinaryFileWrapper", + file_wrapper_config: + { + byteorder: "big", + record_size: 4100, + label_size: 4, + file_extension: ".bin" + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 256, + }, + { + name: "fmow", + description: "Functional Map of the World Dataset (from WILDS/Wild-time)", + version: "0.0.1", + base_path: "/datasets/fmow", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "SingleSampleFileWrapper", + file_wrapper_config: + { + file_extension: ".png", + label_file_extension: ".label" + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 1024, + }, + { + name: "arxiv", + description: "Arxiv Dataset (from Wild-time)", + version: "0.0.1", + base_path: "/datasets/arxiv", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "SingleSampleFileWrapper", + file_wrapper_config: + { + file_extension: ".txt", + label_file_extension: ".label" + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 4096, + }, + { + name: "huffpost", + description: "Huffpost Dataset (from Wild-time)", + version: "0.0.1", + base_path: "/datasets/huffpost", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "CsvFileWrapper", + file_wrapper_config: + { + file_extension: ".csv", + separator: "\t", #tsv best option here since headlines contain commas and semicolons + label_index: 1 + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 4096, + }, + { + name: "arxiv", + description: "Arxiv Dataset (from Wild-time)", + version: "0.0.1", + base_path: "/datasets/arxiv", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "CsvFileWrapper", + file_wrapper_config: + { + file_extension: ".csv", + separator: "\t", #tsv best option here since sentences contain commas and semicolons + label_index: 1 + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 4096, } ] database: diff --git a/modyn/config/schema/pipeline-schema.yaml b/modyn/config/schema/pipeline-schema.yaml index 701a40323..4a8f624db 100644 --- a/modyn/config/schema/pipeline-schema.yaml +++ b/modyn/config/schema/pipeline-schema.yaml @@ -325,6 +325,10 @@ properties: type: string description: | (Optional) function used to transform the label (tensors of integers). + tokenizer: + type: string + description: | + (Optional) Function to tokenize the input. Must be a class in modyn.models.tokenizers. required: - dataset_id - bytes_parser_function diff --git a/modyn/models/README.md b/modyn/models/README.md index 3a6df3953..07c308bb7 100644 --- a/modyn/models/README.md +++ b/modyn/models/README.md @@ -1,3 +1,8 @@ # Custom models -The user can define models here. The model definition should take as a parameter a 'model_configuration' dictionary with architecture-specific parameters. As an example, see the 'ResNet18' model. \ No newline at end of file +The user can define models here. The model definition should take as a parameter a 'model_configuration' dictionary with architecture-specific parameters. As an example, see the 'ResNet18' model. + +# Wild Time models +The code for the models used for WildTime is taken from the official [repository](https://github.com/huaxiuyao/Wild-Time). +The original version is linked in each class. +You can find [here](https://raw.githubusercontent.com/huaxiuyao/Wild-Time/main/LICENSE) a copy of the MIT license \ No newline at end of file diff --git a/modyn/models/__init__.py b/modyn/models/__init__.py index 4f614f1c6..2786a4ad3 100644 --- a/modyn/models/__init__.py +++ b/modyn/models/__init__.py @@ -3,8 +3,11 @@ """ import os +from .articlenet.articlenet import ArticleNet # noqa: F401 from .dlrm.dlrm import DLRM # noqa: F401 +from .fmownet.fmownet import FmowNet # noqa: F401 from .resnet18.resnet18 import ResNet18 # noqa: F401 +from .yearbooknet.yearbooknet import YearbookNet # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/models/articlenet/__init__.py b/modyn/models/articlenet/__init__.py new file mode 100644 index 000000000..443bcf6b1 --- /dev/null +++ b/modyn/models/articlenet/__init__.py @@ -0,0 +1,8 @@ +""" +DistilBert classifier +""" +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/articlenet/articlenet.py b/modyn/models/articlenet/articlenet.py new file mode 100644 index 000000000..c22681a6e --- /dev/null +++ b/modyn/models/articlenet/articlenet.py @@ -0,0 +1,54 @@ +from typing import Any + +import torch +from torch import nn +from transformers import DistilBertModel + + +class ArticleNet: + """ + Adapted from WildTime. This network is used for NLP tasks (Arxiv and Huffpost) + Here you can find the original implementation: + https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/networks/article.py + """ + + # pylint: disable-next=unused-argument + def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None: + self.model = ArticleNetwork(**model_configuration) + self.model.to(device) + + +class DistilBertFeaturizer(DistilBertModel): + def __init__(self, config: Any) -> None: + super().__init__(config) + self.d_out = config.hidden_size + + def __call__(self, data: torch.Tensor) -> torch.Tensor: + # slice the input tensor to get input ids and attention mask + # The model receives as input the output of the tokenizer, where the first dimension + # contains the tokens and the second a boolean mask to indicate which tokens are valid + # (the sentences have different lengths but the output of the tokenizer has always the same size, + # so you need the mask to understand what is useful data and what is just padding) + input_ids = data[:, :, 0] + attention_mask = data[:, :, 1] + # DistilBert's forward pass + hidden_state = super().__call__( + input_ids=input_ids, + attention_mask=attention_mask, + )[ + 0 + ] # 0: last hidden state, 1: hiddent states, 2: attentions + pooled_output = hidden_state[:, 0] # first token is the pooled output, which is the aggregated representation + # of the entire input sequence + return pooled_output + + +class ArticleNetwork(nn.Module): + def __init__(self, num_classes: int) -> None: + super().__init__() + self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased") + self.classifier = nn.Linear(self.featurizer.d_out, num_classes) + + def forward(self, data: torch.Tensor) -> torch.Tensor: + embedding = self.featurizer(data) + return self.classifier(embedding) diff --git a/modyn/models/fmownet/__init__.py b/modyn/models/fmownet/__init__.py new file mode 100644 index 000000000..6466cc57f --- /dev/null +++ b/modyn/models/fmownet/__init__.py @@ -0,0 +1,8 @@ +""" +Densenet121 + linear layer +""" +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/fmownet/fmownet.py b/modyn/models/fmownet/fmownet.py new file mode 100644 index 000000000..1be61a5b0 --- /dev/null +++ b/modyn/models/fmownet/fmownet.py @@ -0,0 +1,35 @@ +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn +from torchvision.models import densenet121 + + +class FmowNet: + """ + Adapted from WildTime. + Here you can find the original implementation: + https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/networks/fmow.py + """ + + # pylint: disable-next=unused-argument + def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None: + self.model = FmowNetModel(**model_configuration) + self.model.to(device) + + +class FmowNetModel(nn.Module): + def __init__(self, num_classes: int) -> None: + super().__init__() + self.num_classes = num_classes + self.enc = densenet121(pretrained=True).features + self.classifier = nn.Linear(1024, self.num_classes) + + def forward(self, data: torch.Tensor) -> torch.Tensor: + features = self.enc(data) + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)) + out = torch.flatten(out, 1) + + return self.classifier(out) diff --git a/modyn/models/tokenizers/__init__.py b/modyn/models/tokenizers/__init__.py new file mode 100644 index 000000000..f75e9a15a --- /dev/null +++ b/modyn/models/tokenizers/__init__.py @@ -0,0 +1,10 @@ +""" +Bert Tokenizer for NLP tasks +""" +import os + +from .distill_bert_tokenizer import DistilBertTokenizerTransform # noqa: F401 + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/tokenizers/distill_bert_tokenizer.py b/modyn/models/tokenizers/distill_bert_tokenizer.py new file mode 100644 index 000000000..ece976992 --- /dev/null +++ b/modyn/models/tokenizers/distill_bert_tokenizer.py @@ -0,0 +1,24 @@ +import torch +from transformers import DistilBertTokenizer + + +class DistilBertTokenizerTransform: + """ + Adapted from WildTime's initialize_distilbert_transform + Here you can find the original implementation: + https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/data/utils.py + """ + + def __init__(self, max_token_length: int = 300) -> None: + self.max_token_length = max_token_length + self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + + def __call__(self, sample: str) -> torch.Tensor: + # make the class Callable to use it as Torch Transform + tokens = self.tokenizer( + sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" + ) + # create a tensor whose first dimension is the input_ids and the second is the attention_mask + data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + data = torch.squeeze(data, dim=0) # First shape dim is always 1, since the input is just one string + return data diff --git a/modyn/models/yearbooknet/__init__.py b/modyn/models/yearbooknet/__init__.py new file mode 100644 index 000000000..46776015e --- /dev/null +++ b/modyn/models/yearbooknet/__init__.py @@ -0,0 +1,8 @@ +""" +Custom CNN for Yearbook dataset +""" +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/yearbooknet/yearbooknet.py b/modyn/models/yearbooknet/yearbooknet.py new file mode 100644 index 000000000..a0f1b4dc6 --- /dev/null +++ b/modyn/models/yearbooknet/yearbooknet.py @@ -0,0 +1,41 @@ +from typing import Any + +import torch +from torch import nn + + +class YearbookNet: + """ + Adapted from WildTime. + Here you can find the original implementation: + https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/networks/yearbook.py + """ + + # pylint: disable-next=unused-argument + def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None: + self.model = YearbookNetModel(**model_configuration) + self.model.to(device) + + +class YearbookNetModel(nn.Module): + def __init__(self, num_input_channels: int, num_classes: int) -> None: + super().__init__() + self.enc = nn.Sequential( + self.conv_block(num_input_channels, 32), + self.conv_block(32, 32), + self.conv_block(32, 32), + self.conv_block(32, 32), + ) + self.hid_dim = 32 + self.classifier = nn.Linear(32, num_classes) + + def conv_block(self, in_channels: int, out_channels: int) -> nn.Module: + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(), nn.MaxPool2d(2) + ) + + def forward(self, data: torch.Tensor) -> torch.Tensor: + data = self.enc(data) + data = torch.mean(data, dim=(2, 3)) + + return self.classifier(data) diff --git a/modyn/protos/trainer_server.proto b/modyn/protos/trainer_server.proto index 800bf3d9a..bafda2a4f 100644 --- a/modyn/protos/trainer_server.proto +++ b/modyn/protos/trainer_server.proto @@ -54,6 +54,7 @@ message StartTrainingRequest { JsonString grad_scaler_configuration = 20; int32 epochs_per_trigger = 21; optional int32 seed = 22; + optional PythonString tokenizer = 23; } message StartTrainingResponse { diff --git a/modyn/selector/internal/selector_manager.py b/modyn/selector/internal/selector_manager.py index a325d9b15..be15cdf45 100644 --- a/modyn/selector/internal/selector_manager.py +++ b/modyn/selector/internal/selector_manager.py @@ -177,6 +177,6 @@ def cleanup_trigger_samples(self) -> None: "cleanup_trigger_samples_after_shutdown" in self._modyn_config["selector"] and "trigger_sample_directory" in self._modyn_config["selector"] ): - shutil.rmtree(self._modyn_config["selector"]["cleanup_trigger_samples_after_shutdown"]) + shutil.rmtree(self._modyn_config["selector"]["trigger_sample_directory"]) Path(self._modyn_config["selector"]["trigger_sample_directory"]).mkdir(parents=True, exist_ok=True) logger.info("Deleted the trigger sample directory.") diff --git a/modyn/storage/internal/grpc/storage_grpc_servicer.py b/modyn/storage/internal/grpc/storage_grpc_servicer.py index ed8083c5d..219eb5c65 100644 --- a/modyn/storage/internal/grpc/storage_grpc_servicer.py +++ b/modyn/storage/internal/grpc/storage_grpc_servicer.py @@ -30,8 +30,8 @@ RegisterNewDatasetResponse, ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageServicer -from modyn.utils import current_time_millis, get_partition_for_worker -from sqlalchemy import asc, select +from modyn.utils.utils import current_time_millis, get_partition_for_worker +from sqlalchemy import and_, asc, select from sqlalchemy.orm import Session logger = logging.getLogger(__name__) @@ -74,7 +74,10 @@ def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[Ge return samples: list[Sample] = ( - session.query(Sample).filter(Sample.sample_id.in_(request.keys)).order_by(Sample.file_id).all() + session.query(Sample) + .filter(and_(Sample.sample_id.in_(request.keys), Sample.dataset_id == dataset.dataset_id)) + .order_by(Sample.file_id) + .all() ) if len(samples) == 0: diff --git a/modyn/supervisor/internal/grpc_handler.py b/modyn/supervisor/internal/grpc_handler.py index 0919ef8b1..0eac4daee 100644 --- a/modyn/supervisor/internal/grpc_handler.py +++ b/modyn/supervisor/internal/grpc_handler.py @@ -292,6 +292,11 @@ def start_training( else: seed = None + if "tokenizer" in pipeline_config["data"]: + tokenizer = pipeline_config["data"]["tokenizer"] + else: + tokenizer = None + if "transformations" in pipeline_config["data"]: transform_list = pipeline_config["data"]["transformations"] else: @@ -349,6 +354,7 @@ def start_training( "grad_scaler_configuration": TrainerServerJsonString(value=json.dumps(grad_scaler_config)), "epochs_per_trigger": epochs_per_trigger, "seed": seed, + "tokenizer": PythonString(value=tokenizer) if tokenizer is not None else None, } cleaned_kwargs = {k: v for k, v in start_training_kwargs.items() if v is not None} diff --git a/modyn/tests/models/test_bert_tokenizer.py b/modyn/tests/models/test_bert_tokenizer.py new file mode 100644 index 000000000..7cc676d8c --- /dev/null +++ b/modyn/tests/models/test_bert_tokenizer.py @@ -0,0 +1,46 @@ +import torch +from modyn.models.tokenizers import DistilBertTokenizerTransform + + +def test_distil_bert_tokenizer_transform(): + max_token_length = 40 + transform = DistilBertTokenizerTransform(max_token_length=max_token_length) + + # Test input string + input_text = "This is a test sentence." + + # Expected output tensors + expected_input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 6251, 1012, 102] + [0] * 32], dtype=torch.int64) + expected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1] + [0] * 32], dtype=torch.int64) + + # Call the transform on the input text + output = transform(input_text) + + # Check if the output has the correct shape and data type + assert output.shape == (max_token_length, 2) + assert output.dtype == torch.int64 + + # Check if the output contains the expected input_ids and attention_mask tensors + assert torch.all(torch.eq(output[:, 0], expected_input_ids)) + assert torch.all(torch.eq(output[:, 1], expected_attention_mask)) + + +def test_distil_bert_tokenizer_transform_empty_input(): + max_token_length = 300 + transform = DistilBertTokenizerTransform(max_token_length=max_token_length) + + # Test empty input string + input_text = "" + + # Call the transform on the empty input text + output = transform(input_text) + + # Check if the output has the correct shape and data type + assert output.shape == (max_token_length, 2) + assert output.dtype == torch.int64 + + # Check if the output contains only the [CLS] and [SEP] tokens + expected_input_ids = torch.tensor([[101, 102] + [0] * 298], dtype=torch.int64) + expected_attention_mask = torch.tensor([[1, 1] + [0] * 298], dtype=torch.int64) + assert torch.all(torch.eq(output[:, 0], expected_input_ids)) + assert torch.all(torch.eq(output[:, 1], expected_attention_mask)) diff --git a/modyn/tests/models/test_yearbook_net.py b/modyn/tests/models/test_yearbook_net.py new file mode 100644 index 000000000..a21526c43 --- /dev/null +++ b/modyn/tests/models/test_yearbook_net.py @@ -0,0 +1,44 @@ +import torch +from modyn.models import YearbookNet + + +def get_model(): + # Create an instance of the YearbookNetModel with the desired parameters for testing + num_input_channels = 3 + num_classes = 10 + return YearbookNet({"num_input_channels": num_input_channels, "num_classes": num_classes}, "cpu", False) + + +def test_model_forward_pass(): + # Create a random input tensor with the appropriate shape + batch_size = 16 + height = 32 + width = 32 + input_channels = 3 + num_classes = 10 + input_data = torch.randn(batch_size, input_channels, height, width) + + # Perform a forward pass through the model + model = get_model() + output = model.model(input_data) + + # Assert that the output has the correct shape + assert output.shape == (batch_size, num_classes) + + +def test_model_conv_block(): + # Test the conv_block method of the model + + # Create a random input tensor with the appropriate shape + batch_size = 16 + height = 32 + width = 32 + input_channels = 3 + input_data = torch.randn(batch_size, input_channels, height, width) + + model = get_model() + # Get the output of the conv_block method + output = model.model.conv_block(input_channels, 32)(input_data) + + # Assert that the output has the correct shape + assert output.shape == (batch_size, 32, height // 2, width // 2) diff --git a/modyn/tests/trainer_server/internal/data/test_data_utils.py b/modyn/tests/trainer_server/internal/data/test_data_utils.py index 28073a447..60dc797ec 100644 --- a/modyn/tests/trainer_server/internal/data/test_data_utils.py +++ b/modyn/tests/trainer_server/internal/data/test_data_utils.py @@ -29,7 +29,7 @@ def noop_constructor_mock(self, channel: grpc.Channel) -> None: def test_prepare_dataloaders( test_weights, test_insecure_channel, test_grpc_connection_established, test_grpc_connection_established_selector ): - train_dataloader, _ = prepare_dataloaders(1, 1, "MNIST", 4, 128, get_mock_bytes_parser(), [], "", "", 42) + train_dataloader, _ = prepare_dataloaders(1, 1, "MNIST", 4, 128, get_mock_bytes_parser(), [], "", "", 42, None) assert train_dataloader.num_workers == 4 assert train_dataloader.batch_size == 128 diff --git a/modyn/tests/trainer_server/internal/data/test_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_online_dataset.py index 54aaaff2e..0ade01f83 100644 --- a/modyn/tests/trainer_server/internal/data/test_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_online_dataset.py @@ -65,6 +65,7 @@ def test_invalid_bytes_parser(test_weights, test_grpc_connection_established): storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, )._init_transforms() with pytest.raises(ValueError): @@ -77,6 +78,7 @@ def test_invalid_bytes_parser(test_weights, test_grpc_connection_established): storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer="", )._init_transforms() @@ -98,6 +100,7 @@ def test_init(test_insecure_channel, test_grpc_connection_established, test_grpc storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) assert online_dataset._pipeline_id == 1 assert online_dataset._trigger_id == 1 @@ -128,6 +131,7 @@ def test_get_keys_and_weights_from_selector( "storage_address": "localhost:1234", "selector_address": "localhost:1234", "training_id": 42, + "tokenizer": None, } online_dataset = OnlineDataset(**kwargs) @@ -160,6 +164,7 @@ def test_get_data_from_storage( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) online_dataset._init_grpc() assert online_dataset._get_data_from_storage(list(range(10))) == ( @@ -217,9 +222,10 @@ def test_deserialize_torchvision_transforms( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) online_dataset._bytes_parser_function = bytes_parser_function - online_dataset._deserialize_torchvision_transforms() + online_dataset._setup_composed_transform() assert isinstance(online_dataset._transform.transforms, list) assert online_dataset._transform.transforms[0].__name__ == "bytes_parser_function" for transform1, transform2 in zip(online_dataset._transform.transforms[1:], transforms_list): @@ -256,6 +262,7 @@ def test_dataset_iter( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataset_iter = iter(online_dataset) all_data = list(dataset_iter) @@ -294,6 +301,7 @@ def test_dataset_iter_with_parsing( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataset_iter = iter(online_dataset) all_data = list(dataset_iter) @@ -332,6 +340,7 @@ def test_dataloader_dataset( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) for i, batch in enumerate(dataloader): @@ -371,6 +380,7 @@ def test_dataloader_dataset_weighted( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) for i, batch in enumerate(dataloader): @@ -414,6 +424,7 @@ def test_dataloader_dataset_multi_worker( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4, num_workers=4) for batch in dataloader: @@ -441,6 +452,7 @@ def test_init_grpc(test_insecure_channel, test_grpc_connection_established, test storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) assert online_dataset._storagestub is None @@ -472,12 +484,13 @@ def test_init_transforms( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) assert online_dataset._bytes_parser_function is None assert online_dataset._transform is None - with patch.object(online_dataset, "_deserialize_torchvision_transforms") as tv_ds: + with patch.object(online_dataset, "_setup_composed_transform") as tv_ds: online_dataset._init_transforms() assert online_dataset._bytes_parser_function is not None assert online_dataset._bytes_parser_function(b"\x01") == 1 @@ -533,6 +546,7 @@ def test_iter_multi_partition( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) @@ -591,6 +605,7 @@ def test_iter_multi_partition_weighted( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) @@ -651,6 +666,7 @@ def test_iter_multi_partition_cross( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=6) @@ -721,6 +737,7 @@ def test_iter_multi_partition_multi_workers( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4, num_workers=4) idx = 0 @@ -762,6 +779,7 @@ def test_multi_epoch_dataloader_dataset( storage_address="localhost:1234", selector_address="localhost:1234", training_id=42, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) for _ in range(5): diff --git a/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py index fcd4e538c..5f554fdd3 100644 --- a/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_per_class_online_dataset.py @@ -67,6 +67,7 @@ def test_dataloader_dataset( selector_address="localhost:1234", training_id=42, initial_filtered_label=0, + tokenizer=None, ) dataloader = torch.utils.data.DataLoader(online_dataset, batch_size=4) diff --git a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py index 7f57d4b2e..7108e7b97 100644 --- a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py +++ b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py @@ -126,6 +126,7 @@ def mock_get_dataloaders( storage_address, selector_address, training_id, + tokenizer, ): mock_train_dataloader = iter( [(("1",) * 8, torch.ones(8, 10, requires_grad=True), torch.ones(8, dtype=int)) for _ in range(100)] diff --git a/modyn/trainer_server/internal/dataset/data_utils.py b/modyn/trainer_server/internal/dataset/data_utils.py index 62c9318df..59dfbb2ea 100644 --- a/modyn/trainer_server/internal/dataset/data_utils.py +++ b/modyn/trainer_server/internal/dataset/data_utils.py @@ -19,6 +19,7 @@ def prepare_dataloaders( storage_address: str, selector_address: str, training_id: int, + tokenizer: Optional[str], ) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]: """ Gets the proper dataset according to the dataset id, and creates the proper dataloaders. @@ -32,6 +33,7 @@ def prepare_dataloaders( bytes_parser (str): Serialized Python code, used for converting bytes to a form useful for futher transformations (such as Tensors). transform (list[str]): List of serialized torchvision transforms for the samples, before loading. + tokenizer (optional[str]): Optional tokenizer for NLP tasks storage_address (str): Address of the Storage endpoint that the OnlineDataset workers connect to. selector_address (str): Address of the Selector endpoint that the OnlineDataset workers connect to. Returns: @@ -40,7 +42,15 @@ def prepare_dataloaders( """ logger.debug("Creating OnlineDataset.") train_set = OnlineDataset( - pipeline_id, trigger_id, dataset_id, bytes_parser, transform, storage_address, selector_address, training_id + pipeline_id, + trigger_id, + dataset_id, + bytes_parser, + transform, + storage_address, + selector_address, + training_id, + tokenizer, ) logger.debug("Creating DataLoader.") train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_dataloaders) @@ -64,5 +74,6 @@ def prepare_per_class_dataloader_from_online_dataset( online_dataset._selector_address, online_dataset._training_id, initial_filtered_label, + online_dataset._tokenizer_name, ) return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index f6f1bef88..3d48ba5f5 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -14,6 +14,7 @@ MAX_MESSAGE_SIZE, deserialize_function, grpc_connection_established, + instantiate_class, ) from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -35,6 +36,7 @@ def __init__( storage_address: str, selector_address: str, training_id: int, + tokenizer: Optional[str], ): self._pipeline_id = pipeline_id self._trigger_id = trigger_id @@ -56,6 +58,12 @@ def __init__( self._key_source = SelectorKeySource(self._pipeline_id, self._trigger_id, self._selector_address) self._uses_weights = None + # tokenizer for NLP tasks + self._tokenizer = None + self._tokenizer_name = tokenizer + if tokenizer is not None: + self._tokenizer = instantiate_class("modyn.models.tokenizers", tokenizer) + logger.debug("Initialized OnlineDataset.") def change_key_source(self, source: AbstractKeySource) -> None: @@ -75,20 +83,24 @@ def _get_data_from_storage(self, selector_keys: list[int]) -> tuple[list[bytes], return sample_list, label_list - def _deserialize_torchvision_transforms(self) -> None: + def _setup_composed_transform(self) -> None: assert self._bytes_parser_function is not None self._transform_list = [self._bytes_parser_function] for transform in self._serialized_transforms: function = eval(transform) # pylint: disable=eval-used self._transform_list.append(function) + + if self._tokenizer is not None: + self._transform_list.append(self._tokenizer) + if len(self._transform_list) > 0: self._transform = transforms.Compose(self._transform_list) def _init_transforms(self) -> None: self._bytes_parser_function = deserialize_function(self._bytes_parser, BYTES_PARSER_FUNC_NAME) self._transform = self._bytes_parser_function - self._deserialize_torchvision_transforms() + self._setup_composed_transform() def _init_grpc(self) -> None: storage_channel = grpc.insecure_channel( @@ -149,14 +161,17 @@ def _unpack_data_tuple(self, data_tuple: Tuple) -> Tuple[int, bytes, int, Option def _get_data_tuple(self, key: int, sample: bytes, label: int, weight: Optional[float]) -> Optional[Tuple]: assert self._uses_weights is not None # mypy complains here because _transform has unknown type, which is ok + tranformed_sample = self._transform(sample) # type: ignore + if self._uses_weights: - return key, self._transform(sample), label, weight # type: ignore - return key, self._transform(sample), label # type: ignore + return key, tranformed_sample, label, weight + return key, tranformed_sample, label def end_of_trigger_cleaning(self) -> None: self._key_source.end_of_trigger_cleaning() # pylint: disable=too-many-locals, too-many-branches + def __iter__(self) -> Generator: worker_info = get_worker_info() if worker_info is None: diff --git a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py index bd715ac17..626d5a4d0 100644 --- a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py +++ b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py @@ -20,6 +20,7 @@ def __init__( selector_address: str, training_id: int, initial_filtered_label: int, + tokenizer: Optional[str], ): super().__init__( pipeline_id, @@ -30,6 +31,7 @@ def __init__( storage_address, selector_address, training_id, + tokenizer, ) assert initial_filtered_label is not None self.filtered_label = initial_filtered_label diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py index 1417e7a29..300ccddf2 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14trainer_server.proto\x12\x07trainer\"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t\"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t\"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\"\x19\n\x17TrainerAvailableRequest\"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t\"\xfc\x05\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x0b\n\x03\x61mp\x18\x04 \x01(\x08\x12\x10\n\x08model_id\x18\x05 \x01(\t\x12\x30\n\x13model_configuration\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1c\n\x14use_pretrained_model\x18\x07 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x08 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\t \x01(\x05\x12\x12\n\nbatch_size\x18\n \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x0b \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\x0c \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\r \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0e \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0f \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x11 \x03(\t\x12)\n\x0clr_scheduler\x18\x12 \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x13 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x14 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x15 \x01(\x05\x12\x11\n\x04seed\x18\x16 \x01(\x05H\x00\x88\x01\x01\x42\x07\n\x05_seed\"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05\",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"\x84\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12\x16\n\texception\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x07 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\x08 \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\t \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\n \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen\"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05\",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse\"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse\"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse\"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse\"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14trainer_server.proto\x12\x07trainer\"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t\"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t\"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\"\x19\n\x17TrainerAvailableRequest\"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t\"\xb9\x06\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x0b\n\x03\x61mp\x18\x04 \x01(\x08\x12\x10\n\x08model_id\x18\x05 \x01(\t\x12\x30\n\x13model_configuration\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1c\n\x14use_pretrained_model\x18\x07 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x08 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\t \x01(\x05\x12\x12\n\nbatch_size\x18\n \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x0b \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\x0c \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\r \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0e \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0f \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x11 \x03(\t\x12)\n\x0clr_scheduler\x18\x12 \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x13 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x14 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x15 \x01(\x05\x12\x11\n\x04seed\x18\x16 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x17 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x42\x07\n\x05_seedB\x0c\n\n_tokenizer\"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05\",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"\x84\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12\x16\n\texception\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x07 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\x08 \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\t \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\n \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen\"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05\",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse\"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse\"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse\"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse\"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'trainer_server_pb2', globals()) @@ -34,21 +34,21 @@ _CHECKPOINTINFO._serialized_start=220 _CHECKPOINTINFO._serialized_end=290 _STARTTRAININGREQUEST._serialized_start=293 - _STARTTRAININGREQUEST._serialized_end=1057 - _STARTTRAININGRESPONSE._serialized_start=1059 - _STARTTRAININGRESPONSE._serialized_end=1129 - _TRAININGSTATUSREQUEST._serialized_start=1131 - _TRAININGSTATUSREQUEST._serialized_end=1175 - _TRAININGSTATUSRESPONSE._serialized_start=1178 - _TRAININGSTATUSRESPONSE._serialized_end=1566 - _STOREFINALMODELREQUEST._serialized_start=1568 - _STOREFINALMODELREQUEST._serialized_end=1613 - _STOREFINALMODELRESPONSE._serialized_start=1615 - _STOREFINALMODELRESPONSE._serialized_end=1679 - _GETLATESTMODELREQUEST._serialized_start=1681 - _GETLATESTMODELREQUEST._serialized_end=1725 - _GETLATESTMODELRESPONSE._serialized_start=1727 - _GETLATESTMODELRESPONSE._serialized_end=1792 - _TRAINERSERVER._serialized_start=1795 - _TRAINERSERVER._serialized_end=2252 + _STARTTRAININGREQUEST._serialized_end=1118 + _STARTTRAININGRESPONSE._serialized_start=1120 + _STARTTRAININGRESPONSE._serialized_end=1190 + _TRAININGSTATUSREQUEST._serialized_start=1192 + _TRAININGSTATUSREQUEST._serialized_end=1236 + _TRAININGSTATUSRESPONSE._serialized_start=1239 + _TRAININGSTATUSRESPONSE._serialized_end=1627 + _STOREFINALMODELREQUEST._serialized_start=1629 + _STOREFINALMODELREQUEST._serialized_end=1674 + _STOREFINALMODELRESPONSE._serialized_start=1676 + _STOREFINALMODELRESPONSE._serialized_end=1740 + _GETLATESTMODELREQUEST._serialized_start=1742 + _GETLATESTMODELREQUEST._serialized_end=1786 + _GETLATESTMODELRESPONSE._serialized_start=1788 + _GETLATESTMODELRESPONSE._serialized_end=1853 + _TRAINERSERVER._serialized_start=1856 + _TRAINERSERVER._serialized_end=2313 # @@protoc_insertion_point(module_scope) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi index fae4b64c3..1ecc8cd7f 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi @@ -134,6 +134,7 @@ class StartTrainingRequest(google.protobuf.message.Message): GRAD_SCALER_CONFIGURATION_FIELD_NUMBER: builtins.int EPOCHS_PER_TRIGGER_FIELD_NUMBER: builtins.int SEED_FIELD_NUMBER: builtins.int + TOKENIZER_FIELD_NUMBER: builtins.int pipeline_id: builtins.int trigger_id: builtins.int device: builtins.str @@ -166,6 +167,8 @@ class StartTrainingRequest(google.protobuf.message.Message): def grad_scaler_configuration(self) -> global___JsonString: ... epochs_per_trigger: builtins.int seed: builtins.int + @property + def tokenizer(self) -> global___PythonString: ... def __init__( self, *, @@ -191,10 +194,14 @@ class StartTrainingRequest(google.protobuf.message.Message): grad_scaler_configuration: global___JsonString | None = ..., epochs_per_trigger: builtins.int = ..., seed: builtins.int | None = ..., + tokenizer: global___PythonString | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["_seed", b"_seed", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "lr_scheduler", b"lr_scheduler", "model_configuration", b"model_configuration", "seed", b"seed", "torch_optimizers_configuration", b"torch_optimizers_configuration"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["_seed", b"_seed", "amp", b"amp", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "epochs_per_trigger", b"epochs_per_trigger", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "model_configuration", b"model_configuration", "model_id", b"model_id", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "seed", b"seed", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "lr_scheduler", b"lr_scheduler", "model_configuration", b"model_configuration", "seed", b"seed", "tokenizer", b"tokenizer", "torch_optimizers_configuration", b"torch_optimizers_configuration"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "amp", b"amp", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "epochs_per_trigger", b"epochs_per_trigger", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "model_configuration", b"model_configuration", "model_id", b"model_id", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "seed", b"seed", "tokenizer", b"tokenizer", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ... + @typing.overload def WhichOneof(self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]) -> typing_extensions.Literal["seed"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_tokenizer", b"_tokenizer"]) -> typing_extensions.Literal["tokenizer"] | None: ... global___StartTrainingRequest = StartTrainingRequest diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py index 4df1222c5..a109397dc 100644 --- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py +++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py @@ -147,6 +147,7 @@ def __init__( training_info.storage_address, training_info.selector_address, training_info.training_id, + training_info.tokenizer, ) # create callbacks - For now, assume LossCallback by default diff --git a/modyn/trainer_server/internal/utils/training_info.py b/modyn/trainer_server/internal/utils/training_info.py index f3ccd087a..cd71e4726 100644 --- a/modyn/trainer_server/internal/utils/training_info.py +++ b/modyn/trainer_server/internal/utils/training_info.py @@ -71,4 +71,10 @@ def __init__( else: self.seed = None + self.tokenizer: Optional[str] + if request.HasField("tokenizer"): + self.tokenizer = request.tokenizer.value + else: + self.tokenizer = None + self.offline_dataset_path = offline_dataset_path