From c08b24ec53f2361f5e59e8ea86ba68d47f932f70 Mon Sep 17 00:00:00 2001 From: John Staib Matilla Date: Wed, 11 Dec 2024 16:07:46 +0100 Subject: [PATCH] Updated storage and training files to enable text generation tasks. Additionally, made minor updates to the environment and ensured everything passes the standard Python compliance tests, maintaining compatibility with Modyn's existing functionality. --- benchmark/mnist/mnist.yaml | 1 + .../example_pipelines/arxiv.yaml | 1 + .../data_drift_trigger/arxiv_datadrift.yaml | 1 + .../huffpost_datadrift.yaml | 1 + .../yearbook_datadrift.yaml | 1 + .../example_pipelines/fmow.yaml | 2 + .../example_pipelines/huffpost.yaml | 1 + .../example_pipelines/yearbook.yaml | 1 + environment.yml | 17 +- integrationtests/config/dummy.yaml | 1 + integrationtests/config/rho_loss.yaml | 4 +- modyn/common/grpc/grpc_helpers.py | 1 + modyn/config/examples/modyn_config.yaml | 2 +- .../config/schema/pipeline/training/config.py | 5 + modyn/config/schema/system/config.py | 2 +- modyn/models/GPT2/GPT2.py | 365 ++++++ modyn/models/GPT2/GPT2_Model_LoRA.py | 1115 +++++++++++++++++ modyn/models/GPT2/_init_.py | 5 + modyn/models/__init__.py | 1 + modyn/models/tokenizers/__init__.py | 2 +- modyn/models/tokenizers/gpt2_tokenizer.py | 28 + modyn/protos/storage.proto | 9 + modyn/protos/trainer_server.proto | 1 + modyn/storage/include/internal/grpc/draft.hpp | 152 +++ .../internal/grpc/storage_service_impl.hpp | 304 ++++- .../internal/grpc/generated/storage_pb2.py | 97 +- .../internal/grpc/generated/storage_pb2.pyi | 153 +-- .../grpc/generated/storage_pb2_grpc.py | 572 +++++---- .../internal/database/grpc2/storage_pb2.py | 65 + .../internal/database/grpc2/storage_pb2.pyi | 425 +++++++ .../database/grpc2/storage_pb2_grpc.py | 396 ++++++ .../internal/grpc/storage_service_impl.cpp | 5 + .../internal/grpc/supervisor_grpc_servicer.py | 2 - .../pipeline_executor/evaluation_executor.py | 1 + .../grpc/test_trainer_server_grpc_servicer.py | 1 + .../internal/dataset/online_dataset.py | 528 ++++++-- .../grpc/generated/trainer_server_pb2.py | 81 +- .../grpc/generated/trainer_server_pb2.pyi | 211 +--- .../grpc/generated/trainer_server_pb2_grpc.py | 335 ++--- .../internal/trainer/pytorch_trainer.py | 76 +- .../internal/utils/training_info.py | 3 +- 41 files changed, 3933 insertions(+), 1041 deletions(-) create mode 100644 modyn/models/GPT2/GPT2.py create mode 100644 modyn/models/GPT2/GPT2_Model_LoRA.py create mode 100644 modyn/models/GPT2/_init_.py create mode 100644 modyn/models/tokenizers/gpt2_tokenizer.py create mode 100644 modyn/storage/include/internal/grpc/draft.hpp create mode 100644 modyn/storage/src/internal/database/grpc2/storage_pb2.py create mode 100644 modyn/storage/src/internal/database/grpc2/storage_pb2.pyi create mode 100644 modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py diff --git a/benchmark/mnist/mnist.yaml b/benchmark/mnist/mnist.yaml index ea0765945..3d065c22a 100644 --- a/benchmark/mnist/mnist.yaml +++ b/benchmark/mnist/mnist.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cuda:0" + generative: False dataloader_workers: 2 use_previous_model: True initial_model: random diff --git a/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml index a4c126594..ee6195f91 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml @@ -17,6 +17,7 @@ training: initial_model: random batch_size: 128 shuffle: True + generative: False optimizers: - name: "default" algorithm: "SGD" diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml index 7173bff24..48789b135 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 96 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml index 4ce54ab12..8e60f7d00 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml index e1d1f1aec..336ea8e5a 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml @@ -14,6 +14,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml index e376b47ab..a24598377 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml @@ -13,6 +13,8 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False + use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml index 27eda1029..a93557042 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cuda:0" + generative: False dataloader_workers: 2 use_previous_model: True initial_model: random diff --git a/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml index cdb4d4c64..8f441c08e 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml @@ -14,6 +14,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/environment.yml b/environment.yml index ca9f223b0..47563adca 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - psycopg2 - sqlalchemy>=2.0 - pyaml - - pydantic + - pydantic==2.9.2 - numpy==1.26.* - pandas - bitstring @@ -43,11 +43,10 @@ dependencies: - nltk - pytorch::pytorch=2.2.1 - pytorch::torchvision - - pytorch::cpuonly # comment out if commenting in lines below for CUDA -# - pytorch::pytorch-cuda=12.1 -# - nvidia::cuda-libraries-dev=12.1.* -# - nvidia::cuda-nvcc=12.1.* -# - nvidia::cuda-nvtx=12.1.* -# - nvidia::cuda-cupti=12.1.* -# - nvidia::cuda-cudart-dev=12.1.* -# - nvidia::cuda-profiler-api=12.1.* + - pytorch::pytorch-cuda=12.1 + - nvidia::cuda-libraries-dev=12.1.* + - nvidia::cuda-nvcc=12.1.* + - nvidia::cuda-nvtx=12.1.* + - nvidia::cuda-cupti=12.1.* + - nvidia::cuda-cudart-dev=12.1.* + - nvidia::cuda-profiler-api==12.1.* diff --git a/integrationtests/config/dummy.yaml b/integrationtests/config/dummy.yaml index f37fe9cae..b9fdd5b6c 100644 --- a/integrationtests/config/dummy.yaml +++ b/integrationtests/config/dummy.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cpu" + generative: False dataloader_workers: 1 use_previous_model: True initial_model: random diff --git a/integrationtests/config/rho_loss.yaml b/integrationtests/config/rho_loss.yaml index a30748c44..f2f89e76d 100644 --- a/integrationtests/config/rho_loss.yaml +++ b/integrationtests/config/rho_loss.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cpu" dataloader_workers: 2 + generative: False use_previous_model: False initial_model: random batch_size: 4 @@ -60,6 +61,7 @@ selection_strategy: il_model_config: num_classes: 10 device: "cpu" + generative: False dataloader_workers: 1 use_previous_model: False batch_size: 2 @@ -75,4 +77,4 @@ selection_strategy: lr: 0.1 momentum: 0.001 optimization_criterion: - name: "CrossEntropyLoss" + name: "CrossEntropyLoss" \ No newline at end of file diff --git a/modyn/common/grpc/grpc_helpers.py b/modyn/common/grpc/grpc_helpers.py index f115d0c22..f8a3f2a24 100644 --- a/modyn/common/grpc/grpc_helpers.py +++ b/modyn/common/grpc/grpc_helpers.py @@ -251,6 +251,7 @@ def prepare_start_training_request( enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements, record_loss_every=training_config.record_loss_every, drop_last_batch=training_config.drop_last_batch, + generative=training_config.generative, ) def start_training( diff --git a/modyn/config/examples/modyn_config.yaml b/modyn/config/examples/modyn_config.yaml index 42ec1588e..6ed900978 100644 --- a/modyn/config/examples/modyn_config.yaml +++ b/modyn/config/examples/modyn_config.yaml @@ -278,7 +278,7 @@ selector: local_storage_directory: "/tmp/local_storage" local_storage_max_samples_in_file: 1000000 cleanup_storage_directories_after_shutdown: true - ignore_existing_trigger_samples: false + ignore_existing_trigger_samples: true trainer_server: hostname: "trainer_server" diff --git a/modyn/config/schema/pipeline/training/config.py b/modyn/config/schema/pipeline/training/config.py index b3f673553..98859587d 100644 --- a/modyn/config/schema/pipeline/training/config.py +++ b/modyn/config/schema/pipeline/training/config.py @@ -119,6 +119,11 @@ class TrainingConfig(ModynBaseModel): "we start with random weights. If initial_model is 'pretrained', cannot be False." ) ) + generative: bool = Field(False, + description=( + "If True then, then the training pipeline goes into the generative branch, data is sampled without expecting labels." + ) + ) seed: int | None = Field( None, description=( diff --git a/modyn/config/schema/system/config.py b/modyn/config/schema/system/config.py index 881f0fb01..4827b6e9b 100644 --- a/modyn/config/schema/system/config.py +++ b/modyn/config/schema/system/config.py @@ -255,7 +255,7 @@ class SelectorConfig(HostnamePortMixin): ), ) ignore_existing_trigger_samples: bool = Field( - False, + True, description=( "Whether to ignore existing trigger samples when starting the selector. If set to false, the trigger " "sample directory has to be empty upon startup. May lead to unexpected behaviour if set to true and the " diff --git a/modyn/models/GPT2/GPT2.py b/modyn/models/GPT2/GPT2.py new file mode 100644 index 000000000..2ebf5cd49 --- /dev/null +++ b/modyn/models/GPT2/GPT2.py @@ -0,0 +1,365 @@ +#plimport pytorch_lightning as pl + +from transformers import ( + Adafactor, + GPT2LMHeadModel, + GPT2Tokenizer, +) + +import torch +#from Datasets import CustomDataset, Pretrain_Chunks +from torch.utils.data import RandomSampler +from torch.utils.data import DataLoader, ConcatDataset +from collections import Counter +from typing import Optional, Tuple, Dict, Union, Any, Callable +from typing import Any, Union,List +import re +import string +#from deepspeed.runtime.lr_schedules import WarmupDecayLR +#import deepspeed +import math +import os +import csv +from modyn.models.GPT2.GPT2_Model_LoRA import GPT2LMHeadModel #as GPT2_Lora +#from modyn.models.GPT2.RecAdam import RecAdam +from modyn.models.coreset_methods_support import CoresetSupportingModule + + +class GPT2: + # pylint: disable-next=unused-argument + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = GPT2Modyn(hparams) + self.model.to(device) + + +# the following class is adapted from +# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py + +class GPT2Modyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super(GPT2Modyn, self).__init__() + # self.save_hyperparameters(hparams) + self.unchanged_loss: float = 0.0 + self.changed_loss: float = 0.0 + self.invariant_loss: float = 0.0 + self.unchanged: int = 0 + self.changed: int = 0 + self.invariant: int = 0 + self.validation: int = 0 + self.validation_loss: float = 0.0 + + self.mix_ratio: float = 1.0 + self.mix_decay: float = 0.7 + self.epoch: int = 0 + + self.model = GPT2LMHeadModel.from_pretrained("gpt2-large") # hparams.model_name_or_path + + def freeze_params(self, model: torch.nn.Module) -> None: + for par in model.parameters(): + par.requires_grad = False + + def normalize_answer(self, s: str) -> str: + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text: str) -> str: + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text: str) -> str: + return " ".join(text.split()) + + def remove_punc(text: str) -> str: + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text: str) -> str: + return text.lower() + + def rid_of_specials(text: str) -> str: + text = text.replace("", "") + text = text.replace("", "") + return text + + return rid_of_specials(white_space_fix(remove_articles(remove_punc(lower(s))))) + + def exact_match_score(self, prediction: str, ground_truth: str) -> int: + return int(self.normalize_answer(prediction) == self.normalize_answer(ground_truth)) + + def _f1_score(self, prediction: str, ground_truth: str) -> float: + prediction_tokens = self.normalize_answer(prediction).split() + ground_truth_tokens = self.normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0.0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + return (2 * precision * recall) / (precision + recall) + + def calculate_scores(self, predictions: list[str], ground_truths: list[str]) -> tuple[float, float]: + em_score: float = 0.0 + f1_score: float = 0.0 + + for i in range(len(predictions)): + ground_truth = ground_truths[i] + prediction = predictions[i] + em_score += self.exact_match_score(prediction, ground_truth) + f1_score += self._f1_score(prediction, ground_truth) + + em_score /= len(predictions) + f1_score /= len(predictions) + return em_score * 100, f1_score * 100 + + def lmap(self, f: Callable, x: list[Any]) -> list[Any]: + """list(map(f, x))""" + return list(map(f, x)) + + def is_logger(self) -> bool: + return self.trainer.global_rank <= 0 + + def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + lm_labels: Optional[torch.Tensor] = None) -> Any: + return self.model( + input_ids, + attention_mask=attention_mask, + labels=lm_labels, + ) + + def get_last_layer(self) -> torch.nn.Module: + return self.model.lm_head + + """ + def _step(self, batch): + lm_labels = batch["target_ids"] + lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100 + outputs = self( + input_ids=batch["source_ids"], + attention_mask=batch["source_mask"], + lm_labels=lm_labels, + ) + + loss = outputs[0] + return loss + + def valid_step(self, batch): + lm_labels = batch["label_ids"].clone().detach() + lm_labels[source_nonprompt_mask == 0] = -100 + outputs = self( + input_ids=batch["label_ids"], + attention_mask=batch["label_mask"], + lm_labels=lm_labels, + ) + + loss = outputs[0] + print(loss) + return loss + + + + def ids_to_clean_text(self, generated_ids): + gen_text = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return self.lmap(str.strip, gen_text) + + def _generative_step_finetune(self, batch, batch_idx): + loss = self.valid_step(batch) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.validation +=1 + self.validation_loss += loss + average_loss = self.validation_loss / self.validation + ppl = torch.exp(average_loss) + self.log('validation_ppl', ppl, prog_bar=True, logger=True) + + source = self.ids_to_clean_text(batch["source_ids"]) + generated_ids = self.model.generate( + batch["source_ids"], + attention_mask=batch["source_mask"], + use_cache=True, + max_length=self.hparams.max_input_length + 5, + num_beams=2, + early_stopping=True + ) + targets = self.ids_to_clean_text(batch["target_ids"]) + + generated_ids = torch.transpose(torch.transpose(generated_ids,0,1)[self.hparams.max_input_length:],0,1) + preds = self.ids_to_clean_text(generated_ids) + clean_preds = [] + for text in preds: + if "." in text: + clean_preds.append(text[:text.find(".")+1]) + else: + clean_preds.append(text) + print("clean_preds",clean_preds) + print("targets",targets) + + em_score, f1_score = self.calculate_scores(clean_preds, targets) + print(em_score, f1_score, ppl) + self.log('EM score', em_score, prog_bar=True, logger=True) + self.log('F1 score', f1_score, prog_bar=True, logger=True) + + + def _generative_step(self, batch, batch_idx, dataloader_idx=-1): + loss = self._step(batch) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + if dataloader_idx == 0: + self.unchanged +=1 + self.unchanged_loss += loss + average_loss = self.unchanged_loss / self.unchanged + ppl = torch.exp(average_loss) + self.log('UnC_ppl', ppl, prog_bar=True, logger=True) + print('UnC_ppl', ppl) + elif dataloader_idx == 1: + self.changed +=1 + self.changed_loss += loss + average_loss = self.changed_loss / self.changed + ppl = torch.exp(average_loss) + self.log('C_ppl', ppl, prog_bar=True, logger=True) + print('C_ppl', ppl) + else: + self.invariant +=1 + self.invariant_loss += loss + average_loss = self.invariant_loss / self.invariant + ppl = torch.exp(average_loss) + self.log('IL_ppl', ppl, prog_bar=True, logger=True) + print('IL_ppl', ppl) + + def training_step(self, batch, batch_idx): + loss = self._step(batch) + self.log("loss", loss) + return loss + + def on_train_epoch_end(self): + if self.hparams.mode=='pretrain_brute': + self.dataset_index+=1 + if self.dataset_index==self.hparams.num_files: + self.global_epoch+=1 + self.log('global_epoch', self.global_epoch, prog_bar=True, logger=True) + self.dataset_index=0 + self.train_dataloader() + if self.hparams.method=='mixreview': + train_set = self.train_dataloader().dataset + self.epoch+=1 + + def validation_step(self, batch, batch_idx, dataloader_idx=-1): + if self.hparams.mode == 'finetune': + return self._generative_step_finetune(batch, batch_idx, dataloader_idx) + return self._generative_step(batch, batch_idx, dataloader_idx) + + def configure_optimizers(self, train_len=None): + "Prepare optimizer and schedule (linear warmup and decay)" + if self.hparams.method=='recadam': + no_decay = ["bias", "LayerNorm.weight"] + model_type = 'gpt2' + recadam_anneal_w = 1.0 + recadam_anneal_fun = 'sigmoid' + recadam_anneal_k = 0.5 + recadam_anneal_t0 = 250 + recadam_pretrain_cof = 5000.0 + new_model = self.model + pretrained_model = self.pretrained_model + optimizer_grouped_parameters = [ + { + "params": [p for n, p in new_model.named_parameters() if + not any(nd in n for nd in no_decay) and model_type in n], + "weight_decay": self.hparams.weight_decay, + "anneal_w": recadam_anneal_w, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + not any(nd in p_n for nd in no_decay) and model_type in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + not any(nd in n for nd in no_decay) and model_type not in n], + "weight_decay": self.hparams.weight_decay, + "anneal_w": 0.0, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + not any(nd in p_n for nd in no_decay) and model_type not in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + any(nd in n for nd in no_decay) and model_type in n], + "weight_decay": 0.0, + "anneal_w": recadam_anneal_w, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + any(nd in p_n for nd in no_decay) and model_type in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + any(nd in n for nd in no_decay) and model_type not in n], + "weight_decay": 0.0, + "anneal_w": 0.0, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + any(nd in p_n for nd in no_decay) and model_type not in p_n] + } + ] + optimizer = RecAdam(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon, + anneal_fun=recadam_anneal_fun, anneal_k=recadam_anneal_k, + anneal_t0=recadam_anneal_t0, pretrain_cof=recadam_pretrain_cof) + else: + model = self.model + + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.hparams.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + if self.hparams.accelerator is not None: + optimizer = deepspeed.ops.adam.FusedAdam(optimizer_grouped_parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) + else: + optimizer = Adafactor(optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False) + + if self.hparams.use_lr_scheduling: + if self.hparams.len_data==None: + len_data = len(self.train_dataloader()) + else: + len_data = int(self.hparams.len_data // self.hparams.train_batch_size) + denomniator = (self.hparams.n_gpu * self.hparams.gradient_accumulation_steps) + + steps_per_epoch = ( len_data // denomniator ) + 1 + schedule_scale_factor = 8 + total_num_steps = ( steps_per_epoch * self.hparams.num_train_epochs ) * self.hparams.num_files * schedule_scale_factor + + print(f'total number of steps : {total_num_steps}') + scheduler = WarmupDecayLR(optimizer, total_num_steps = total_num_steps ,warmup_max_lr = self.hparams.learning_rate, warmup_num_steps = int(total_num_steps * 0.1)) + return [optimizer], [{"scheduler": scheduler, "interval": "step", "name": "learning rate"}] + else: + return [optimizer] + + def train_dataloader(self): + if self.hparams.mode=='pretrain_brute': + train_dataset = Pretrain_Chunks(dataset_name=self.dataset_lst[self.dataset_index],tokenizer=self.tokenizer, input_length=self.hparams.max_input_length, output_length=self.hparams.max_output_length, args=self.hparams) + else: + train_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.hparams) + if self.hparams.method=='mixreview': + #mix_len = int(len(train_dataset) * self.mix_ratio * (self.mix_decay ** self.epoch)) + mix_len = int(len(train_dataset)) + pretrain_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="pretrain", args=self.hparams, length=mix_len) + mixed_dataset = ConcatDataset([train_dataset,pretrain_dataset]) + print("mix len is ", mix_len) + sampler = RandomSampler(mixed_dataset) + dataloader = DataLoader(mixed_dataset, sampler = sampler, batch_size=self.hparams.train_batch_size, drop_last=True, num_workers=self.hparams.num_workers) + print("dataset length is ", len(dataloader.dataset)) + else: + sampler = RandomSampler(train_dataset) + dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=self.hparams.train_batch_size, drop_last=True, num_workers=self.hparams.num_workers) + return dataloader + + def val_dataloader(self): + validation_dataset_unchanged = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams, lama_type='unchanged') + validation_dataset_changed = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams, lama_type='changed') + return [DataLoader(validation_dataset_unchanged, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False), + DataLoader(validation_dataset_changed, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False), + ] + def test_dataloader(self): + test_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="test", args=self.hparams) + + return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False) + """ \ No newline at end of file diff --git a/modyn/models/GPT2/GPT2_Model_LoRA.py b/modyn/models/GPT2/GPT2_Model_LoRA.py new file mode 100644 index 000000000..d474d1474 --- /dev/null +++ b/modyn/models/GPT2/GPT2_Model_LoRA.py @@ -0,0 +1,1115 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Union, Any, Callable +from typing import Any, Union,List +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + + +if version.parse(torch.__version__) >= version.parse("1.6"): + is_amp_available = True + from torch.cuda.amp import autocast +else: + is_amp_available = False + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + Conv1D, + PreTrainedModel, + SequenceSummary, + find_pruneable_heads_and_indices, + prune_conv1d_layer, +) +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def load_tf_weights_in_gpt2( + model: torch.nn.Module, + config: Any, + gpt2_checkpoint_path: str +) -> torch.nn.Module: + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + +class LoRALayer(nn.Module): + def __init__( + self, + n_in: int, + n_out: Optional[int] = None, + adapter_dim: int = 16, + adapter_alpha: int = 32 + ) -> None: + super(LoRALayer, self).__init__() + if not n_out: + n_out = n_in + self.adapter_dim = adapter_dim + self.adapter_alpha = adapter_alpha + self.adapter_proj_1 = nn.Linear(n_in, adapter_dim, bias=False) + nn.init.normal_(self.adapter_proj_1.weight, std=0.02) + self.adapter_proj_2 = nn.Linear(adapter_dim, n_out, bias=False) + self.adapter_proj_2.weight.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + scale_factor = self.adapter_dim / self.adapter_alpha + result = torch.matmul(x, self.adapter_proj_1.weight.type_as(x).T) + return torch.matmul(result, self.adapter_proj_2.weight.type_as(x).T) * scale_factor + +class GPT2Attention(nn.Module): + def __init__( + self, + config: Any, # Replace `Any` with the specific config class if known + is_cross_attention: bool = False, + layer_idx: int = -1 +) -> None: + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.lora_attn_dim = 4 + self.lora_attn_alpha = 16 + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.q_lora = LoRALayer(self.embed_dim, adapter_dim=self.lora_attn_dim, adapter_alpha=self.lora_attn_alpha) + self.v_lora = LoRALayer(self.embed_dim, adapter_dim=self.lora_attn_dim, adapter_alpha=self.lora_attn_alpha) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads: set[int] = set() + + + def prune_heads(self, heads: List[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None + ) -> Any: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None + ) -> Any: + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + if is_amp_available: + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + else: + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads( + self, + tensor: torch.Tensor, + num_heads: int, + attn_head_size: int + ) -> torch.Tensor: + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(*new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads( + self, + tensor: torch.Tensor, + num_heads: int, + attn_head_size: int + ) -> torch.Tensor: + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False + )-> Any: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_delta = self.q_lora(hidden_states) + value_delta = self.v_lora(encoder_hidden_states) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_delta = self.q_lora(hidden_states) + value_delta = self.v_lora(hidden_states) + + query = query.contiguous() + query_delta + value = value.contiguous() + value_delta + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + """ + if use_cache is True: + present = (key, value) + else: + present = None + """ + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output,None )# type: ignore + if output_attentions: + outputs += (attn_weights,) # type: ignore + + return outputs # a, present, (attentions) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size: int, config: Any) -> None: + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + +class GPT2Block(nn.Module): + def __init__(self, config: Any, layer_idx: int = -1) -> None: + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn: nn.Module = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention: nn.Module = GPT2Attention(config, is_cross_attention=True) + self.ln_cross_attn: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp: nn.Module = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + + def __init__(self, *inputs: Any, **kwargs: Any) -> None: + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + def _set_gradient_checkpointing( + self, + module: Union[torch.nn.Module, Any], # Replace `Any` with the actual type of GPT2Model if available + value: bool = False + ) -> None: + """Enable or disable gradient checkpointing for the given module.""" + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + + +GPT2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be + passed as `input_ids`. + Indices can be obtained using [`GPT2Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which + have their past given to this model should not be passed as `input_ids` as they have already been + computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up + decoding (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + Example: + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained('gpt2-xl') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + Example: + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained('gpt2-large') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) + +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing: List[str] = ["attn.masked_bias"] + + def __init__(self, config: Any) -> None: + super().__init__(config) + + self.embed_dim: int = config.hidden_size + + self.wte: nn.Embedding = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe: nn.Embedding = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop: nn.Dropout = nn.Dropout(config.embd_pdrop) + self.h: nn.ModuleList = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f: nn.LayerNorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel: bool = False + self.device_map: Optional[Dict[int, List[int]]] = None + self.gradient_checkpointing: bool = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map: Optional[Dict[int, List[int]]] = None) -> None: + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device: str = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device: str = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self) -> None: + self.model_parallel = False + self.device_map = None + self.first_device: str = "cpu"# type: ignore[no-redef] + self.last_device: str = "cpu"# type: ignore[no-redef] + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self) -> nn.Embedding: + return self.wte + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + Tuple[ + torch.Tensor, + Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]], + Optional[Tuple[torch.Tensor, ...]], + Optional[Tuple[torch.Tensor, ...]], + Optional[Tuple[torch.Tensor, ...]], + ], + BaseModelOutputWithPastAndCrossAttentions, + ]: + # The implementation remains unchanged + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device # type: ignore[union-attr] + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) # type: ignore[list-item] + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents: Tuple[torch.Tensor, ...] = () + all_self_attentions: Tuple[torch.Tensor, ...] = () + all_cross_attentions: Tuple[torch.Tensor, ...] = () + all_hidden_states: Tuple[torch.Tensor, ...] = () + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module: Callable[..., Any]) -> Callable[..., Any]: + def custom_forward(*inputs: Any) -> Any: + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): # type: ignore[union-attr] + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing: List[str] = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config: Any) -> None: + super().__init__(config) + self.transformer: GPT2Model = GPT2Model(config) + self.lm_head: nn.Linear = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel: bool = False + self.device_map: Optional[Dict[int, List[int]]] = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map: Optional[Dict[int, List[int]]] = None) -> None: + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self) -> None: + self.transformer.deparallelize() + self.transformer = self.transformer.to("cuda") + self.lm_head = self.lm_head.to("cuda") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self) -> nn.Linear: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + token_type_ids = kwargs.get("token_type_ids", None) + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + Tuple[torch.Tensor, ...], + CausalLMOutputWithCrossAttentions, + ]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, ...], ...], + beam_idx: torch.Tensor, + ) -> Tuple[Tuple[torch.Tensor, ...], ...]: + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + diff --git a/modyn/models/GPT2/_init_.py b/modyn/models/GPT2/_init_.py new file mode 100644 index 000000000..f5987b52c --- /dev/null +++ b/modyn/models/GPT2/_init_.py @@ -0,0 +1,5 @@ +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/__init__.py b/modyn/models/__init__.py index 8d92f68e8..668e9ee6c 100644 --- a/modyn/models/__init__.py +++ b/modyn/models/__init__.py @@ -12,6 +12,7 @@ from .rho_loss_twin_model.rho_loss_twin_model import RHOLOSSTwinModel # noqa: F401 from .smallyearbooknet.smallyearbooknet import SmallYearbookNet # noqa: F401 from .yearbooknet.yearbooknet import YearbookNet # noqa: F401 +from .GPT2.GPT2 import GPT2 #noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/models/tokenizers/__init__.py b/modyn/models/tokenizers/__init__.py index 93e8d9c8f..47884e93e 100644 --- a/modyn/models/tokenizers/__init__.py +++ b/modyn/models/tokenizers/__init__.py @@ -3,7 +3,7 @@ import os from .distill_bert_tokenizer import DistilBertTokenizerTransform # noqa: F401 - +from .gpt2_tokenizer import GPT2TokenizerTransform # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") __all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/tokenizers/gpt2_tokenizer.py b/modyn/models/tokenizers/gpt2_tokenizer.py new file mode 100644 index 000000000..204abce03 --- /dev/null +++ b/modyn/models/tokenizers/gpt2_tokenizer.py @@ -0,0 +1,28 @@ +import torch +from transformers import GPT2Tokenizer + +class GPT2TokenizerTransform: + def __init__(self, max_token_length: int = 300): + # Load the GPT-2 tokenizer + self.max_token_length = max_token_length + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") + # Set the pad token to the eos token to avoid padding errors + self.tokenizer.add_special_tokens({ + "eos_token": "", + "bos_token": "", + "unk_token": "", + "pad_token": "", + "mask_token": "" + }) + self.tokenizer.padding_side = "left" + + def __call__(self, sample: str) -> torch.Tensor: + # Make the class callable to use it as Torch Transform + tokens = self.tokenizer( + sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" + ) + # Create a tensor whose first dimension is the input_ids and the second is the attention_mask + data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + data = torch.squeeze(data, dim=0) # First shape dim is always 1, since the input is just one string + return data + \ No newline at end of file diff --git a/modyn/protos/storage.proto b/modyn/protos/storage.proto index 4b6fc9dde..37fa4bff8 100644 --- a/modyn/protos/storage.proto +++ b/modyn/protos/storage.proto @@ -4,6 +4,7 @@ package modyn.storage; service Storage { rpc Get(GetRequest) returns (stream GetResponse) {} + rpc GetNL(GetRequest) returns (stream GetResponseNoLabels) {} rpc GetNewDataSince(GetNewDataSinceRequest) returns (stream GetNewDataSinceResponse) {} rpc GetDataInInterval(GetDataInIntervalRequest) @@ -25,6 +26,7 @@ service Storage { message GetRequest { string dataset_id = 1; repeated int64 keys = 2; + bool include_labels = 3; // Added this line } message GetResponse { @@ -33,6 +35,13 @@ message GetResponse { repeated int64 labels = 3; } + +message GetResponseNoLabels { + repeated bytes samples = 1; + repeated int64 keys = 2; + +} + // https://github.com/grpc/grpc/issues/15937 message GetCurrentTimestampRequest {} diff --git a/modyn/protos/trainer_server.proto b/modyn/protos/trainer_server.proto index 0c7343c70..6eda308c4 100644 --- a/modyn/protos/trainer_server.proto +++ b/modyn/protos/trainer_server.proto @@ -59,6 +59,7 @@ message StartTrainingRequest { bool enable_accurate_gpu_measurements = 25; int64 record_loss_every = 26; bool drop_last_batch = 27; + bool generative=28; } message StartTrainingResponse { diff --git a/modyn/storage/include/internal/grpc/draft.hpp b/modyn/storage/include/internal/grpc/draft.hpp new file mode 100644 index 000000000..f233dd948 --- /dev/null +++ b/modyn/storage/include/internal/grpc/draft.hpp @@ -0,0 +1,152 @@ +//"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +// Adapted function to send sample data without labels +template > +static void send_sample_data_for_keys_and_file_NL( // NOLINT(readability-function-cognitive-complexity) //TODO Adaptar esto + WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + + // Note that we currently ignore the sample batch size here, under the assumption that users do not request more + // keys than this + try { + const uint64_t num_keys = sample_keys.size(); + + if (num_keys == 0) { + SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file_NL"); + return; + } + + // Removed labels-related vectors + // std::vector sample_labels(num_keys); + std::vector sample_indices(num_keys); + std::vector sample_fileids(num_keys); + + const std::string sample_query = fmt::format( + "SELECT sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER BY file_id", + fmt::join(sample_keys, ",")); + session << sample_query, + soci::into(sample_indices), + soci::into(sample_fileids), + soci::use(dataset_data.dataset_id); + + if (sample_fileids.size() != num_keys) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException( + fmt::format("Got back {} samples from DB, while asking for {} keys. You might have asked for duplicate " + "keys, which is not supported.", + sample_fileids.size(), num_keys)); + } + + int64_t current_file_id = sample_fileids.at(0); + uint64_t current_file_start_idx = 0; + std::string current_file_path; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}, current_file_id = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, current_file_id, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); + auto filesystem_wrapper = + get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); + + auto file_wrapper = + get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + + for (uint64_t sample_idx = 0; sample_idx < num_keys; ++sample_idx) { + const int64_t& sample_fileid = sample_fileids.at(sample_idx); + + if (sample_fileid != current_file_id) { + // 1. Prepare response without labels + const std::vector file_indexes( + sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.begin() + static_cast(sample_idx)); + std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + data.clear(); + data.shrink_to_fit(); + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.begin() + static_cast(sample_idx)); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.begin() + static_cast(sample_idx)); + + // 2. Send response + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + + // 3. Update state + current_file_id = sample_fileid; + current_file_path = ""; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + const int64_t& previous_fid = sample_fileids.at(sample_idx - 1); + SPDLOG_ERROR( + fmt::format("num_keys = {}, sample_idx = {}, previous_fid = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, sample_idx, previous_fid, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + file_wrapper->set_file_path(current_file_path); + current_file_start_idx = sample_idx; + } + } + + // Send leftovers without labels + const std::vector file_indexes(sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.end()); + const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.end()); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.end()); + + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file_NL: {}", e.what()); + SPDLOG_ERROR("Propagating error up the call chain to handle gRPC calls."); + throw; + } +} diff --git a/modyn/storage/include/internal/grpc/storage_service_impl.hpp b/modyn/storage/include/internal/grpc/storage_service_impl.hpp index 875bf2e53..0fc8189ee 100644 --- a/modyn/storage/include/internal/grpc/storage_service_impl.hpp +++ b/modyn/storage/include/internal/grpc/storage_service_impl.hpp @@ -74,6 +74,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { Status Get(ServerContext* context, const modyn::storage::GetRequest* request, ServerWriter* writer) override; + Status GetNL(ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) override; Status GetNewDataSince(ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, ServerWriter* writer) override; Status GetDataInInterval(ServerContext* context, const modyn::storage::GetDataInIntervalRequest* request, @@ -134,6 +136,55 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { return {StatusCode::INTERNAL, fmt::format("Error in Get: {}", e.what())}; } } + //IMPORTANT SECOND FUNCTION STARTS HERE + template + Status Get_Impl_NL( // NOLINT (readability-identifier-naming) + ServerContext* /*context*/, const modyn::storage::GetRequest* request, WriterT* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + std::string dataset_name = request->dataset_id(); + const DatasetData dataset_data = get_dataset_data(session, dataset_name); + + SPDLOG_INFO(fmt::format("Received GetRequest for dataset {} (id = {}) with {} keys.", dataset_name, + dataset_data.dataset_id, request->keys_size())); + + if (dataset_data.dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + session.close(); + return {StatusCode::OK, "Dataset does not exist."}; + } + + const auto keys_size = static_cast(request->keys_size()); + if (keys_size == 0) { + return {StatusCode::OK, "No keys provided."}; + } + + std::vector request_keys; + request_keys.reserve(keys_size); + std::copy(request->keys().begin(), request->keys().end(), std::back_inserter(request_keys)); + + send_sample_data_from_keys_NL(writer, request_keys, dataset_data); //Llegamos hasta aqui + + // sqlite causes memory leaks otherwise + if (session.get_backend_name() != "sqlite3" && session.is_connected()) { + session.close(); + } + + return {StatusCode::OK, "Data retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in Get: {}", e.what()); + return {StatusCode::INTERNAL, fmt::format("Error in Get: {}", e.what())}; + } + } + //ENDS HERE + + + + + + template Status GetNewDataSince_Impl( // NOLINT (readability-identifier-naming) @@ -365,6 +416,66 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { } } } + //EMPIEZA + template > + void send_sample_data_from_keys_NL(WriterT* writer, const std::vector& request_keys, + const DatasetData& dataset_data) { + // Create mutex to protect the writer from concurrent writes as this is not supported by gRPC + std::mutex writer_mutex; + + if (disable_multithreading_) { + const std::vector::const_iterator begin = request_keys.begin(); // NOLINT (modernize-use-auto) + const std::vector::const_iterator end = request_keys.end(); // NOLINT (modernize-use-auto) + + get_samples_and_send_NL(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_);//llegamos aqui + + } else { + std::vector thread_exceptions(retrieval_threads_); + std::mutex exception_mutex; + std::vector::const_iterator, std::vector::const_iterator>> + its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_); + std::vector retrieval_threads_vector(retrieval_threads_); + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + const std::vector::const_iterator begin = its_per_thread[thread_id].first; + const std::vector::const_iterator end = its_per_thread[thread_id].second; + + retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data, + &thread_exceptions, &exception_mutex, this]() { + try { + get_samples_and_send_NL(begin, end, writer, &writer_mutex, &dataset_data, &config_, + sample_batch_size_); + } catch (const std::exception& e) { + const std::lock_guard lock(exception_mutex); + spdlog::error( + fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what())); + thread_exceptions[thread_id] = std::current_exception(); + } + }); + } + + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + if (retrieval_threads_vector[thread_id].joinable()) { + retrieval_threads_vector[thread_id].join(); + } + } + retrieval_threads_vector.clear(); + // In order for the gRPC call to return an error, we need to rethrow the threaded exceptions. + for (auto& e_ptr : thread_exceptions) { + if (e_ptr) { + try { + std::rethrow_exception(e_ptr); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error while unwinding thread: {}\nPropagating it up the call chain.", e.what()); + throw; + } + } + } + } + } +//ACABA AQUI + + + template > void send_file_ids_and_labels(WriterT* writer, const int64_t dataset_id, const int64_t start_timestamp = -1, @@ -690,6 +801,183 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { } } + + + + + + //"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +//"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +// Adapted function to send sample data without labels +template > +static void send_sample_data_for_keys_and_file_NL( // NOLINT(readability-function-cognitive-complexity) //TODO Adaptar esto + WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + + // Note that we currently ignore the sample batch size here, under the assumption that users do not request more + // keys than this + try { + const uint64_t num_keys = sample_keys.size(); + + if (num_keys == 0) { + SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file_NL"); + return; + } + + // Removed labels-related vectors + // std::vector sample_labels(num_keys); + std::vector sample_indices(num_keys); + std::vector sample_fileids(num_keys); + + const std::string sample_query = fmt::format( + "SELECT sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER BY file_id", + fmt::join(sample_keys, ",")); + session << sample_query, + soci::into(sample_indices), + soci::into(sample_fileids), + soci::use(dataset_data.dataset_id); + + if (sample_fileids.size() != num_keys) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException( + fmt::format("Got back {} samples from DB, while asking for {} keys. You might have asked for duplicate " + "keys, which is not supported.", + sample_fileids.size(), num_keys)); + } + + int64_t current_file_id = sample_fileids.at(0); + uint64_t current_file_start_idx = 0; + std::string current_file_path; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}, current_file_id = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, current_file_id, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); + auto filesystem_wrapper = + get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); + + auto file_wrapper = + get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + + for (uint64_t sample_idx = 0; sample_idx < num_keys; ++sample_idx) { + const int64_t& sample_fileid = sample_fileids.at(sample_idx); + + if (sample_fileid != current_file_id) { + // 1. Prepare response without labels + const std::vector file_indexes( + sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.begin() + static_cast(sample_idx)); + std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + data.clear(); + data.shrink_to_fit(); + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.begin() + static_cast(sample_idx)); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.begin() + static_cast(sample_idx)); + + // 2. Send response + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + + // 3. Update state + current_file_id = sample_fileid; + current_file_path = ""; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + const int64_t& previous_fid = sample_fileids.at(sample_idx - 1); + SPDLOG_ERROR( + fmt::format("num_keys = {}, sample_idx = {}, previous_fid = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, sample_idx, previous_fid, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + file_wrapper->set_file_path(current_file_path); + current_file_start_idx = sample_idx; + } + } + + // Send leftovers without labels + const std::vector file_indexes(sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.end()); + const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.end()); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.end()); + + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file_NL: {}", e.what()); + SPDLOG_ERROR("Propagating error up the call chain to handle gRPC calls."); + throw; + } +} + + //"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + + + + + + + + + + + + + + + template static void get_samples_and_send(const std::vector::const_iterator begin, const std::vector::const_iterator end, WriterT* writer, @@ -705,7 +993,21 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { sample_batch_size); session.close(); } - + template + static void get_samples_and_send_NL(const std::vector::const_iterator begin, + const std::vector::const_iterator end, WriterT* writer, + std::mutex* writer_mutex, const DatasetData* dataset_data, const YAML::Node* config, + int64_t sample_batch_size) { + if (begin >= end) { + return; + } + const StorageDatabaseConnection storage_database_connection(*config); + soci::session session = storage_database_connection.get_session(); + const std::vector sample_keys(begin, end); + send_sample_data_for_keys_and_file_NL(writer, *writer_mutex, sample_keys, *dataset_data, session, + sample_batch_size); + session.close(); + } static std::string get_timestamp_condition(const int64_t start_timestamp = -1, const int64_t end_timestamp = -1) { std::string timestamp_filter; if (start_timestamp >= 0 && end_timestamp == -1) { diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.py b/modyn/storage/internal/grpc/generated/storage_pb2.py index 8d59a8b1f..b24fd735d 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: storage.proto -# Protobuf Python Version: 5.26.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" - from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool + from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder @@ -14,55 +15,57 @@ _sym_db = _symbol_database.Default() + from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\rstorage.proto\x12\rmodyn.storage".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\x1c\n\x1aGetCurrentTimestampRequest"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe2\x07\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse"\x00\x62\x06proto3' -) + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rstorage.proto\x12\rmodyn.storage\"F\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x16\n\x0einclude_labels\x18\x03 \x01(\x08\"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"4\n\x13GetResponseNoLabels\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"\x1c\n\x1aGetCurrentTimestampRequest\"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03\"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03\"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03\"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xae\x08\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse\"\x00\x30\x01\x12J\n\x05GetNL\x12\x19.modyn.storage.GetRequest\x1a\".modyn.storage.GetResponseNoLabels\"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse\"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse\"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse\"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse\"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse\"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse\"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse\"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse\"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "storage_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'storage_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_GETREQUEST"]._serialized_start = 32 - _globals["_GETREQUEST"]._serialized_end = 78 - _globals["_GETRESPONSE"]._serialized_start = 80 - _globals["_GETRESPONSE"]._serialized_end = 140 - _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_start = 142 - _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_end = 170 - _globals["_GETNEWDATASINCEREQUEST"]._serialized_start = 172 - _globals["_GETNEWDATASINCEREQUEST"]._serialized_end = 235 - _globals["_GETNEWDATASINCERESPONSE"]._serialized_start = 237 - _globals["_GETNEWDATASINCERESPONSE"]._serialized_end = 312 - _globals["_GETDATAININTERVALREQUEST"]._serialized_start = 314 - _globals["_GETDATAININTERVALREQUEST"]._serialized_end = 408 - _globals["_GETDATAININTERVALRESPONSE"]._serialized_start = 410 - _globals["_GETDATAININTERVALRESPONSE"]._serialized_end = 487 - _globals["_GETDATAPERWORKERREQUEST"]._serialized_start = 490 - _globals["_GETDATAPERWORKERREQUEST"]._serialized_end = 673 - _globals["_GETDATAPERWORKERRESPONSE"]._serialized_start = 675 - _globals["_GETDATAPERWORKERRESPONSE"]._serialized_end = 715 - _globals["_GETDATASETSIZEREQUEST"]._serialized_start = 718 - _globals["_GETDATASETSIZEREQUEST"]._serialized_end = 857 - _globals["_GETDATASETSIZERESPONSE"]._serialized_start = 859 - _globals["_GETDATASETSIZERESPONSE"]._serialized_end = 918 - _globals["_DATASETAVAILABLEREQUEST"]._serialized_start = 920 - _globals["_DATASETAVAILABLEREQUEST"]._serialized_end = 965 - _globals["_DATASETAVAILABLERESPONSE"]._serialized_start = 967 - _globals["_DATASETAVAILABLERESPONSE"]._serialized_end = 1012 - _globals["_REGISTERNEWDATASETREQUEST"]._serialized_start = 1015 - _globals["_REGISTERNEWDATASETREQUEST"]._serialized_end = 1270 - _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_start = 1272 - _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_end = 1317 - _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_start = 1319 - _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_end = 1367 - _globals["_DELETEDATASETRESPONSE"]._serialized_start = 1369 - _globals["_DELETEDATASETRESPONSE"]._serialized_end = 1409 - _globals["_DELETEDATAREQUEST"]._serialized_start = 1411 - _globals["_DELETEDATAREQUEST"]._serialized_end = 1464 - _globals["_DELETEDATARESPONSE"]._serialized_start = 1466 - _globals["_DELETEDATARESPONSE"]._serialized_end = 1503 - _globals["_STORAGE"]._serialized_start = 1506 - _globals["_STORAGE"]._serialized_end = 2500 + DESCRIPTOR._loaded_options = None + _globals['_GETREQUEST']._serialized_start=32 + _globals['_GETREQUEST']._serialized_end=102 + _globals['_GETRESPONSE']._serialized_start=104 + _globals['_GETRESPONSE']._serialized_end=164 + _globals['_GETRESPONSENOLABELS']._serialized_start=166 + _globals['_GETRESPONSENOLABELS']._serialized_end=218 + _globals['_GETCURRENTTIMESTAMPREQUEST']._serialized_start=220 + _globals['_GETCURRENTTIMESTAMPREQUEST']._serialized_end=248 + _globals['_GETNEWDATASINCEREQUEST']._serialized_start=250 + _globals['_GETNEWDATASINCEREQUEST']._serialized_end=313 + _globals['_GETNEWDATASINCERESPONSE']._serialized_start=315 + _globals['_GETNEWDATASINCERESPONSE']._serialized_end=390 + _globals['_GETDATAININTERVALREQUEST']._serialized_start=392 + _globals['_GETDATAININTERVALREQUEST']._serialized_end=486 + _globals['_GETDATAININTERVALRESPONSE']._serialized_start=488 + _globals['_GETDATAININTERVALRESPONSE']._serialized_end=565 + _globals['_GETDATAPERWORKERREQUEST']._serialized_start=568 + _globals['_GETDATAPERWORKERREQUEST']._serialized_end=751 + _globals['_GETDATAPERWORKERRESPONSE']._serialized_start=753 + _globals['_GETDATAPERWORKERRESPONSE']._serialized_end=793 + _globals['_GETDATASETSIZEREQUEST']._serialized_start=796 + _globals['_GETDATASETSIZEREQUEST']._serialized_end=935 + _globals['_GETDATASETSIZERESPONSE']._serialized_start=937 + _globals['_GETDATASETSIZERESPONSE']._serialized_end=996 + _globals['_DATASETAVAILABLEREQUEST']._serialized_start=998 + _globals['_DATASETAVAILABLEREQUEST']._serialized_end=1043 + _globals['_DATASETAVAILABLERESPONSE']._serialized_start=1045 + _globals['_DATASETAVAILABLERESPONSE']._serialized_end=1090 + _globals['_REGISTERNEWDATASETREQUEST']._serialized_start=1093 + _globals['_REGISTERNEWDATASETREQUEST']._serialized_end=1348 + _globals['_REGISTERNEWDATASETRESPONSE']._serialized_start=1350 + _globals['_REGISTERNEWDATASETRESPONSE']._serialized_end=1395 + _globals['_GETCURRENTTIMESTAMPRESPONSE']._serialized_start=1397 + _globals['_GETCURRENTTIMESTAMPRESPONSE']._serialized_end=1445 + _globals['_DELETEDATASETRESPONSE']._serialized_start=1447 + _globals['_DELETEDATASETRESPONSE']._serialized_end=1487 + _globals['_DELETEDATAREQUEST']._serialized_start=1489 + _globals['_DELETEDATAREQUEST']._serialized_end=1542 + _globals['_DELETEDATARESPONSE']._serialized_start=1544 + _globals['_DELETEDATARESPONSE']._serialized_end=1581 + _globals['_STORAGE']._serialized_start=1584 + _globals['_STORAGE']._serialized_end=2654 # @@protoc_insertion_point(module_scope) diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.pyi b/modyn/storage/internal/grpc/generated/storage_pb2.pyi index cf286ada1..a6c340801 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.pyi +++ b/modyn/storage/internal/grpc/generated/storage_pb2.pyi @@ -18,7 +18,10 @@ class GetRequest(google.protobuf.message.Message): DATASET_ID_FIELD_NUMBER: builtins.int KEYS_FIELD_NUMBER: builtins.int + INCLUDE_LABELS_FIELD_NUMBER: builtins.int dataset_id: builtins.str + include_labels: builtins.bool + """Added this line""" @property def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__( @@ -26,8 +29,9 @@ class GetRequest(google.protobuf.message.Message): *, dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., + include_labels: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "include_labels", b"include_labels", "keys", b"keys"]) -> None: ... global___GetRequest = GetRequest @@ -51,12 +55,30 @@ class GetResponse(google.protobuf.message.Message): keys: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"]) -> None: ... global___GetResponse = GetResponse +@typing.final +class GetResponseNoLabels(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "samples", b"samples"]) -> None: ... + +global___GetResponseNoLabels = GetResponseNoLabels + @typing.final class GetCurrentTimestampRequest(google.protobuf.message.Message): """https://github.com/grpc/grpc/issues/15937""" @@ -83,9 +105,7 @@ class GetNewDataSinceRequest(google.protobuf.message.Message): dataset_id: builtins.str = ..., timestamp: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"]) -> None: ... global___GetNewDataSinceRequest = GetNewDataSinceRequest @@ -109,9 +129,7 @@ class GetNewDataSinceResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... global___GetNewDataSinceResponse = GetNewDataSinceResponse @@ -132,12 +150,7 @@ class GetDataInIntervalRequest(google.protobuf.message.Message): start_timestamp: builtins.int = ..., end_timestamp: builtins.int = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp" - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... global___GetDataInIntervalRequest = GetDataInIntervalRequest @@ -161,9 +174,7 @@ class GetDataInIntervalResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... global___GetDataInIntervalResponse = GetDataInIntervalResponse @@ -193,46 +204,12 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): start_timestamp: builtins.int | None = ..., end_timestamp: builtins.int | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "dataset_id", - b"dataset_id", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - "total_workers", - b"total_workers", - "worker_id", - b"worker_id", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp", "total_workers", b"total_workers", "worker_id", b"worker_id"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... global___GetDataPerWorkerRequest = GetDataPerWorkerRequest @@ -272,42 +249,12 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): start_timestamp: builtins.int | None = ..., end_timestamp: builtins.int | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "dataset_id", - b"dataset_id", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... global___GetDatasetSizeRequest = GetDatasetSizeRequest @@ -394,29 +341,7 @@ class RegisterNewDatasetRequest(google.protobuf.message.Message): ignore_last_timestamp: builtins.bool = ..., file_watcher_interval: builtins.int = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "base_path", - b"base_path", - "dataset_id", - b"dataset_id", - "description", - b"description", - "file_watcher_interval", - b"file_watcher_interval", - "file_wrapper_config", - b"file_wrapper_config", - "file_wrapper_type", - b"file_wrapper_type", - "filesystem_wrapper_type", - b"filesystem_wrapper_type", - "ignore_last_timestamp", - b"ignore_last_timestamp", - "version", - b"version", - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["base_path", b"base_path", "dataset_id", b"dataset_id", "description", b"description", "file_watcher_interval", b"file_watcher_interval", "file_wrapper_config", b"file_wrapper_config", "file_wrapper_type", b"file_wrapper_type", "filesystem_wrapper_type", b"filesystem_wrapper_type", "ignore_last_timestamp", b"ignore_last_timestamp", "version", b"version"]) -> None: ... global___RegisterNewDatasetRequest = RegisterNewDatasetRequest diff --git a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py index cae27014a..7cda0b538 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py @@ -1,34 +1,27 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" - +import grpc import warnings -import grpc import modyn.storage.internal.grpc.generated.storage_pb2 as storage__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = '1.67.1' GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) except ImportError: _version_not_supported = True if _version_not_supported: - warnings.warn( - f"The grpc package installed is at version {GRPC_VERSION}," - + f" but the generated code in storage_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." - + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." - + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in storage_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' ) @@ -42,65 +35,60 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Get = channel.unary_stream( - "/modyn.storage.Storage/Get", - request_serializer=storage__pb2.GetRequest.SerializeToString, - response_deserializer=storage__pb2.GetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/Get', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + _registered_method=True) + self.GetNL = channel.unary_stream( + '/modyn.storage.Storage/GetNL', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponseNoLabels.FromString, + _registered_method=True) self.GetNewDataSince = channel.unary_stream( - "/modyn.storage.Storage/GetNewDataSince", - request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, - response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetNewDataSince', + request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, + response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, + _registered_method=True) self.GetDataInInterval = channel.unary_stream( - "/modyn.storage.Storage/GetDataInInterval", - request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDataInInterval', + request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, + _registered_method=True) self.GetDataPerWorker = channel.unary_stream( - "/modyn.storage.Storage/GetDataPerWorker", - request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDataPerWorker', + request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, + _registered_method=True) self.GetDatasetSize = channel.unary_unary( - "/modyn.storage.Storage/GetDatasetSize", - request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, - response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDatasetSize', + request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, + response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, + _registered_method=True) self.CheckAvailability = channel.unary_unary( - "/modyn.storage.Storage/CheckAvailability", - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/CheckAvailability', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, + _registered_method=True) self.RegisterNewDataset = channel.unary_unary( - "/modyn.storage.Storage/RegisterNewDataset", - request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, - response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/RegisterNewDataset', + request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, + response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, + _registered_method=True) self.GetCurrentTimestamp = channel.unary_unary( - "/modyn.storage.Storage/GetCurrentTimestamp", - request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, - response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetCurrentTimestamp', + request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, + response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, + _registered_method=True) self.DeleteDataset = channel.unary_unary( - "/modyn.storage.Storage/DeleteDataset", - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/DeleteDataset', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, + _registered_method=True) self.DeleteData = channel.unary_unary( - "/modyn.storage.Storage/DeleteData", - request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDataResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/DeleteData', + request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDataResponse.FromString, + _registered_method=True) class StorageServicer(object): @@ -109,142 +97,153 @@ class StorageServicer(object): def Get(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetNewDataSince(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDataInInterval(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDataPerWorker(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDatasetSize(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def CheckAvailability(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def RegisterNewDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetCurrentTimestamp(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def DeleteDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def DeleteData(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_StorageServicer_to_server(servicer, server): rpc_method_handlers = { - "Get": grpc.unary_stream_rpc_method_handler( - servicer.Get, - request_deserializer=storage__pb2.GetRequest.FromString, - response_serializer=storage__pb2.GetResponse.SerializeToString, - ), - "GetNewDataSince": grpc.unary_stream_rpc_method_handler( - servicer.GetNewDataSince, - request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, - response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, - ), - "GetDataInInterval": grpc.unary_stream_rpc_method_handler( - servicer.GetDataInInterval, - request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, - response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, - ), - "GetDataPerWorker": grpc.unary_stream_rpc_method_handler( - servicer.GetDataPerWorker, - request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, - response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, - ), - "GetDatasetSize": grpc.unary_unary_rpc_method_handler( - servicer.GetDatasetSize, - request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, - response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, - ), - "CheckAvailability": grpc.unary_unary_rpc_method_handler( - servicer.CheckAvailability, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, - ), - "RegisterNewDataset": grpc.unary_unary_rpc_method_handler( - servicer.RegisterNewDataset, - request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, - response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, - ), - "GetCurrentTimestamp": grpc.unary_unary_rpc_method_handler( - servicer.GetCurrentTimestamp, - request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, - response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, - ), - "DeleteDataset": grpc.unary_unary_rpc_method_handler( - servicer.DeleteDataset, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, - ), - "DeleteData": grpc.unary_unary_rpc_method_handler( - servicer.DeleteData, - request_deserializer=storage__pb2.DeleteDataRequest.FromString, - response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, - ), + 'Get': grpc.unary_stream_rpc_method_handler( + servicer.Get, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponse.SerializeToString, + ), + 'GetNL': grpc.unary_stream_rpc_method_handler( + servicer.GetNL, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponseNoLabels.SerializeToString, + ), + 'GetNewDataSince': grpc.unary_stream_rpc_method_handler( + servicer.GetNewDataSince, + request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, + response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, + ), + 'GetDataInInterval': grpc.unary_stream_rpc_method_handler( + servicer.GetDataInInterval, + request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, + response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, + ), + 'GetDataPerWorker': grpc.unary_stream_rpc_method_handler( + servicer.GetDataPerWorker, + request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, + response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, + ), + 'GetDatasetSize': grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetSize, + request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, + response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, + ), + 'CheckAvailability': grpc.unary_unary_rpc_method_handler( + servicer.CheckAvailability, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, + ), + 'RegisterNewDataset': grpc.unary_unary_rpc_method_handler( + servicer.RegisterNewDataset, + request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, + response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, + ), + 'GetCurrentTimestamp': grpc.unary_unary_rpc_method_handler( + servicer.GetCurrentTimestamp, + request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, + response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, + ), + 'DeleteDataset': grpc.unary_unary_rpc_method_handler( + servicer.DeleteDataset, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, + ), + 'DeleteData': grpc.unary_unary_rpc_method_handler( + servicer.DeleteData, + request_deserializer=storage__pb2.DeleteDataRequest.FromString, + response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler("modyn.storage.Storage", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + 'modyn.storage.Storage', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('modyn.storage.Storage', rpc_method_handlers) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class Storage(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Get( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def Get(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/Get", + '/modyn.storage.Storage/Get', storage__pb2.GetRequest.SerializeToString, storage__pb2.GetResponse.FromString, options, @@ -255,26 +254,50 @@ def Get( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) + + @staticmethod + def GetNL(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/modyn.storage.Storage/GetNL', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponseNoLabels.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod - def GetNewDataSince( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetNewDataSince(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetNewDataSince", + '/modyn.storage.Storage/GetNewDataSince', storage__pb2.GetNewDataSinceRequest.SerializeToString, storage__pb2.GetNewDataSinceResponse.FromString, options, @@ -285,26 +308,23 @@ def GetNewDataSince( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDataInInterval( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDataInInterval(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetDataInInterval", + '/modyn.storage.Storage/GetDataInInterval', storage__pb2.GetDataInIntervalRequest.SerializeToString, storage__pb2.GetDataInIntervalResponse.FromString, options, @@ -315,26 +335,23 @@ def GetDataInInterval( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDataPerWorker( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDataPerWorker(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetDataPerWorker", + '/modyn.storage.Storage/GetDataPerWorker', storage__pb2.GetDataPerWorkerRequest.SerializeToString, storage__pb2.GetDataPerWorkerResponse.FromString, options, @@ -345,26 +362,23 @@ def GetDataPerWorker( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDatasetSize( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDatasetSize(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/GetDatasetSize", + '/modyn.storage.Storage/GetDatasetSize', storage__pb2.GetDatasetSizeRequest.SerializeToString, storage__pb2.GetDatasetSizeResponse.FromString, options, @@ -375,26 +389,23 @@ def GetDatasetSize( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def CheckAvailability( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def CheckAvailability(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/CheckAvailability", + '/modyn.storage.Storage/CheckAvailability', storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DatasetAvailableResponse.FromString, options, @@ -405,26 +416,23 @@ def CheckAvailability( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def RegisterNewDataset( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def RegisterNewDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/RegisterNewDataset", + '/modyn.storage.Storage/RegisterNewDataset', storage__pb2.RegisterNewDatasetRequest.SerializeToString, storage__pb2.RegisterNewDatasetResponse.FromString, options, @@ -435,26 +443,23 @@ def RegisterNewDataset( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetCurrentTimestamp( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetCurrentTimestamp(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/GetCurrentTimestamp", + '/modyn.storage.Storage/GetCurrentTimestamp', storage__pb2.GetCurrentTimestampRequest.SerializeToString, storage__pb2.GetCurrentTimestampResponse.FromString, options, @@ -465,26 +470,23 @@ def GetCurrentTimestamp( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def DeleteDataset( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def DeleteDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/DeleteDataset", + '/modyn.storage.Storage/DeleteDataset', storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DeleteDatasetResponse.FromString, options, @@ -495,26 +497,23 @@ def DeleteDataset( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def DeleteData( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def DeleteData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/DeleteData", + '/modyn.storage.Storage/DeleteData', storage__pb2.DeleteDataRequest.SerializeToString, storage__pb2.DeleteDataResponse.FromString, options, @@ -525,5 +524,4 @@ def DeleteData( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2.py b/modyn/storage/src/internal/database/grpc2/storage_pb2.py new file mode 100644 index 000000000..18ff61e95 --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: storage.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rstorage.proto\x12\rmodyn.storage\"F\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x16\n\x0einclude_labels\x18\x03 \x01(\x08\"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"4\n\x13GetResponseNoLabels\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"\x1c\n\x1aGetCurrentTimestampRequest\"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03\"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03\"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03\"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xae\x08\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse\"\x00\x30\x01\x12J\n\x05GetNL\x12\x19.modyn.storage.GetRequest\x1a\".modyn.storage.GetResponseNoLabels\"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse\"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse\"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse\"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse\"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse\"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse\"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse\"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse\"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'storage_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _GETREQUEST._serialized_start=32 + _GETREQUEST._serialized_end=102 + _GETRESPONSE._serialized_start=104 + _GETRESPONSE._serialized_end=164 + _GETRESPONSENOLABELS._serialized_start=166 + _GETRESPONSENOLABELS._serialized_end=218 + _GETCURRENTTIMESTAMPREQUEST._serialized_start=220 + _GETCURRENTTIMESTAMPREQUEST._serialized_end=248 + _GETNEWDATASINCEREQUEST._serialized_start=250 + _GETNEWDATASINCEREQUEST._serialized_end=313 + _GETNEWDATASINCERESPONSE._serialized_start=315 + _GETNEWDATASINCERESPONSE._serialized_end=390 + _GETDATAININTERVALREQUEST._serialized_start=392 + _GETDATAININTERVALREQUEST._serialized_end=486 + _GETDATAININTERVALRESPONSE._serialized_start=488 + _GETDATAININTERVALRESPONSE._serialized_end=565 + _GETDATAPERWORKERREQUEST._serialized_start=568 + _GETDATAPERWORKERREQUEST._serialized_end=751 + _GETDATAPERWORKERRESPONSE._serialized_start=753 + _GETDATAPERWORKERRESPONSE._serialized_end=793 + _GETDATASETSIZEREQUEST._serialized_start=796 + _GETDATASETSIZEREQUEST._serialized_end=935 + _GETDATASETSIZERESPONSE._serialized_start=937 + _GETDATASETSIZERESPONSE._serialized_end=996 + _DATASETAVAILABLEREQUEST._serialized_start=998 + _DATASETAVAILABLEREQUEST._serialized_end=1043 + _DATASETAVAILABLERESPONSE._serialized_start=1045 + _DATASETAVAILABLERESPONSE._serialized_end=1090 + _REGISTERNEWDATASETREQUEST._serialized_start=1093 + _REGISTERNEWDATASETREQUEST._serialized_end=1348 + _REGISTERNEWDATASETRESPONSE._serialized_start=1350 + _REGISTERNEWDATASETRESPONSE._serialized_end=1395 + _GETCURRENTTIMESTAMPRESPONSE._serialized_start=1397 + _GETCURRENTTIMESTAMPRESPONSE._serialized_end=1445 + _DELETEDATASETRESPONSE._serialized_start=1447 + _DELETEDATASETRESPONSE._serialized_end=1487 + _DELETEDATAREQUEST._serialized_start=1489 + _DELETEDATAREQUEST._serialized_end=1542 + _DELETEDATARESPONSE._serialized_start=1544 + _DELETEDATARESPONSE._serialized_end=1581 + _STORAGE._serialized_start=1584 + _STORAGE._serialized_end=2654 +# @@protoc_insertion_point(module_scope) diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi b/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi new file mode 100644 index 000000000..a6c340801 --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi @@ -0,0 +1,425 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class GetRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + INCLUDE_LABELS_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + include_labels: builtins.bool + """Added this line""" + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + dataset_id: builtins.str = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + include_labels: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "include_labels", b"include_labels", "keys", b"keys"]) -> None: ... + +global___GetRequest = GetRequest + +@typing.final +class GetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"]) -> None: ... + +global___GetResponse = GetResponse + +@typing.final +class GetResponseNoLabels(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "samples", b"samples"]) -> None: ... + +global___GetResponseNoLabels = GetResponseNoLabels + +@typing.final +class GetCurrentTimestampRequest(google.protobuf.message.Message): + """https://github.com/grpc/grpc/issues/15937""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___GetCurrentTimestampRequest = GetCurrentTimestampRequest + +@typing.final +class GetNewDataSinceRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"]) -> None: ... + +global___GetNewDataSinceRequest = GetNewDataSinceRequest + +@typing.final +class GetNewDataSinceResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + TIMESTAMPS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + timestamps: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... + +global___GetNewDataSinceResponse = GetNewDataSinceResponse + +@typing.final +class GetDataInIntervalRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + start_timestamp: builtins.int + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + start_timestamp: builtins.int = ..., + end_timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... + +global___GetDataInIntervalRequest = GetDataInIntervalRequest + +@typing.final +class GetDataInIntervalResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + TIMESTAMPS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + timestamps: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... + +global___GetDataInIntervalResponse = GetDataInIntervalResponse + +@typing.final +class GetDataPerWorkerRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + WORKER_ID_FIELD_NUMBER: builtins.int + TOTAL_WORKERS_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + worker_id: builtins.int + total_workers: builtins.int + start_timestamp: builtins.int + """value unset means no limit + start_timestamp is inclusive, end_timestamp is exclusive + """ + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + worker_id: builtins.int = ..., + total_workers: builtins.int = ..., + start_timestamp: builtins.int | None = ..., + end_timestamp: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp", "total_workers", b"total_workers", "worker_id", b"worker_id"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... + +global___GetDataPerWorkerRequest = GetDataPerWorkerRequest + +@typing.final +class GetDataPerWorkerResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys"]) -> None: ... + +global___GetDataPerWorkerResponse = GetDataPerWorkerResponse + +@typing.final +class GetDatasetSizeRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + start_timestamp: builtins.int + """value unset means no limit + start_timestamp is inclusive, end_timestamp is exclusive + """ + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + start_timestamp: builtins.int | None = ..., + end_timestamp: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... + +global___GetDatasetSizeRequest = GetDatasetSizeRequest + +@typing.final +class GetDatasetSizeResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + NUM_KEYS_FIELD_NUMBER: builtins.int + success: builtins.bool + num_keys: builtins.int + def __init__( + self, + *, + success: builtins.bool = ..., + num_keys: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["num_keys", b"num_keys", "success", b"success"]) -> None: ... + +global___GetDatasetSizeResponse = GetDatasetSizeResponse + +@typing.final +class DatasetAvailableRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + def __init__( + self, + *, + dataset_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id"]) -> None: ... + +global___DatasetAvailableRequest = DatasetAvailableRequest + +@typing.final +class DatasetAvailableResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + AVAILABLE_FIELD_NUMBER: builtins.int + available: builtins.bool + def __init__( + self, + *, + available: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["available", b"available"]) -> None: ... + +global___DatasetAvailableResponse = DatasetAvailableResponse + +@typing.final +class RegisterNewDatasetRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + FILESYSTEM_WRAPPER_TYPE_FIELD_NUMBER: builtins.int + FILE_WRAPPER_TYPE_FIELD_NUMBER: builtins.int + DESCRIPTION_FIELD_NUMBER: builtins.int + BASE_PATH_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + FILE_WRAPPER_CONFIG_FIELD_NUMBER: builtins.int + IGNORE_LAST_TIMESTAMP_FIELD_NUMBER: builtins.int + FILE_WATCHER_INTERVAL_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + filesystem_wrapper_type: builtins.str + file_wrapper_type: builtins.str + description: builtins.str + base_path: builtins.str + version: builtins.str + file_wrapper_config: builtins.str + ignore_last_timestamp: builtins.bool + file_watcher_interval: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + filesystem_wrapper_type: builtins.str = ..., + file_wrapper_type: builtins.str = ..., + description: builtins.str = ..., + base_path: builtins.str = ..., + version: builtins.str = ..., + file_wrapper_config: builtins.str = ..., + ignore_last_timestamp: builtins.bool = ..., + file_watcher_interval: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["base_path", b"base_path", "dataset_id", b"dataset_id", "description", b"description", "file_watcher_interval", b"file_watcher_interval", "file_wrapper_config", b"file_wrapper_config", "file_wrapper_type", b"file_wrapper_type", "filesystem_wrapper_type", b"filesystem_wrapper_type", "ignore_last_timestamp", b"ignore_last_timestamp", "version", b"version"]) -> None: ... + +global___RegisterNewDatasetRequest = RegisterNewDatasetRequest + +@typing.final +class RegisterNewDatasetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___RegisterNewDatasetResponse = RegisterNewDatasetResponse + +@typing.final +class GetCurrentTimestampResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TIMESTAMP_FIELD_NUMBER: builtins.int + timestamp: builtins.int + def __init__( + self, + *, + timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["timestamp", b"timestamp"]) -> None: ... + +global___GetCurrentTimestampResponse = GetCurrentTimestampResponse + +@typing.final +class DeleteDatasetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___DeleteDatasetResponse = DeleteDatasetResponse + +@typing.final +class DeleteDataRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + dataset_id: builtins.str = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + +global___DeleteDataRequest = DeleteDataRequest + +@typing.final +class DeleteDataResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___DeleteDataResponse = DeleteDataResponse diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py b/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py new file mode 100644 index 000000000..60712130f --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py @@ -0,0 +1,396 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import storage_pb2 as storage__pb2 + + +class StorageStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Get = channel.unary_stream( + '/modyn.storage.Storage/Get', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + ) + self.GetNL = channel.unary_stream( + '/modyn.storage.Storage/GetNL', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponseNoLabels.FromString, + ) + self.GetNewDataSince = channel.unary_stream( + '/modyn.storage.Storage/GetNewDataSince', + request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, + response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, + ) + self.GetDataInInterval = channel.unary_stream( + '/modyn.storage.Storage/GetDataInInterval', + request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, + ) + self.GetDataPerWorker = channel.unary_stream( + '/modyn.storage.Storage/GetDataPerWorker', + request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, + ) + self.GetDatasetSize = channel.unary_unary( + '/modyn.storage.Storage/GetDatasetSize', + request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, + response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, + ) + self.CheckAvailability = channel.unary_unary( + '/modyn.storage.Storage/CheckAvailability', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, + ) + self.RegisterNewDataset = channel.unary_unary( + '/modyn.storage.Storage/RegisterNewDataset', + request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, + response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, + ) + self.GetCurrentTimestamp = channel.unary_unary( + '/modyn.storage.Storage/GetCurrentTimestamp', + request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, + response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, + ) + self.DeleteDataset = channel.unary_unary( + '/modyn.storage.Storage/DeleteDataset', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, + ) + self.DeleteData = channel.unary_unary( + '/modyn.storage.Storage/DeleteData', + request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDataResponse.FromString, + ) + + +class StorageServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Get(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNewDataSince(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDataInInterval(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDataPerWorker(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDatasetSize(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CheckAvailability(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RegisterNewDataset(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetCurrentTimestamp(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteDataset(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteData(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_StorageServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Get': grpc.unary_stream_rpc_method_handler( + servicer.Get, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponse.SerializeToString, + ), + 'GetNL': grpc.unary_stream_rpc_method_handler( + servicer.GetNL, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponseNoLabels.SerializeToString, + ), + 'GetNewDataSince': grpc.unary_stream_rpc_method_handler( + servicer.GetNewDataSince, + request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, + response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, + ), + 'GetDataInInterval': grpc.unary_stream_rpc_method_handler( + servicer.GetDataInInterval, + request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, + response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, + ), + 'GetDataPerWorker': grpc.unary_stream_rpc_method_handler( + servicer.GetDataPerWorker, + request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, + response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, + ), + 'GetDatasetSize': grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetSize, + request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, + response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, + ), + 'CheckAvailability': grpc.unary_unary_rpc_method_handler( + servicer.CheckAvailability, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, + ), + 'RegisterNewDataset': grpc.unary_unary_rpc_method_handler( + servicer.RegisterNewDataset, + request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, + response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, + ), + 'GetCurrentTimestamp': grpc.unary_unary_rpc_method_handler( + servicer.GetCurrentTimestamp, + request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, + response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, + ), + 'DeleteDataset': grpc.unary_unary_rpc_method_handler( + servicer.DeleteDataset, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, + ), + 'DeleteData': grpc.unary_unary_rpc_method_handler( + servicer.DeleteData, + request_deserializer=storage__pb2.DeleteDataRequest.FromString, + response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'modyn.storage.Storage', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Storage(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Get(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/Get', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetNL(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetNL', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponseNoLabels.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetNewDataSince(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetNewDataSince', + storage__pb2.GetNewDataSinceRequest.SerializeToString, + storage__pb2.GetNewDataSinceResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDataInInterval(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataInInterval', + storage__pb2.GetDataInIntervalRequest.SerializeToString, + storage__pb2.GetDataInIntervalResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDataPerWorker(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataPerWorker', + storage__pb2.GetDataPerWorkerRequest.SerializeToString, + storage__pb2.GetDataPerWorkerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDatasetSize(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetDatasetSize', + storage__pb2.GetDatasetSizeRequest.SerializeToString, + storage__pb2.GetDatasetSizeResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CheckAvailability(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/CheckAvailability', + storage__pb2.DatasetAvailableRequest.SerializeToString, + storage__pb2.DatasetAvailableResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def RegisterNewDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/RegisterNewDataset', + storage__pb2.RegisterNewDatasetRequest.SerializeToString, + storage__pb2.RegisterNewDatasetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetCurrentTimestamp(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetCurrentTimestamp', + storage__pb2.GetCurrentTimestampRequest.SerializeToString, + storage__pb2.GetCurrentTimestampResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/DeleteDataset', + storage__pb2.DatasetAvailableRequest.SerializeToString, + storage__pb2.DeleteDatasetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/DeleteData', + storage__pb2.DeleteDataRequest.SerializeToString, + storage__pb2.DeleteDataResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/storage/src/internal/grpc/storage_service_impl.cpp b/modyn/storage/src/internal/grpc/storage_service_impl.cpp index 4640f2593..105f48420 100644 --- a/modyn/storage/src/internal/grpc/storage_service_impl.cpp +++ b/modyn/storage/src/internal/grpc/storage_service_impl.cpp @@ -17,6 +17,11 @@ Status StorageServiceImpl::Get( // NOLINT readability-identifier-naming ServerWriter* writer) { return Get_Impl>(context, request, writer); } +Status StorageServiceImpl::GetNL( // NOLINT readability-identifier-naming + ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) { + return Get_Impl_NL>(context, request, writer); +} Status StorageServiceImpl::GetNewDataSince( // NOLINT readability-identifier-naming ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, diff --git a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py index 3f38d0a7e..e3790451c 100644 --- a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py +++ b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py @@ -30,7 +30,6 @@ def __init__(self, supervisor: Supervisor, modyn_config: dict) -> None: def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerContext) -> PipelineResponse: tid = threading.get_native_id() pid = os.getpid() - logger.info(f"[{pid}][{tid}]: Starting pipeline with request - {request}") start_replay_at: int | None = None @@ -44,7 +43,6 @@ def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerCo maximum_triggers = request.maximum_triggers pipeline_config = json.loads(request.pipeline_config.value) - msg = self._supervisor.start_pipeline( pipeline_config, self.modyn_config["supervisor"]["eval_directory"], diff --git a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py index 75333d46f..e4b641135 100644 --- a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py +++ b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py @@ -84,6 +84,7 @@ def register_tracking_info(self, tracking_dfs: dict[str, pd.DataFrame], dataset_ tracking_dfs: A dictionary of dataframes containing tracking information. dataset_end_time: Timestamp of the last sample in the dataset. """ + print(PipelineStage.STORE_TRAINED_MODEL.name) assert tracking_dfs.get(PipelineStage.HANDLE_SINGLE_TRIGGER.name) is not None assert tracking_dfs.get(PipelineStage.STORE_TRAINED_MODEL.name) is not None self.context = AfterPipelineEvalContext(tracking_dfs=tracking_dfs, dataset_end_time=dataset_end_time) diff --git a/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py b/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py index f3f89bc41..dec828a3d 100644 --- a/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py +++ b/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py @@ -157,6 +157,7 @@ def get_start_training_request(checkpoint_path=""): pretrained_model_id=-1, lr_scheduler=JsonString(value=json.dumps({})), grad_scaler_configuration=JsonString(value=json.dumps({})), + generative=False ) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 48856aaa8..c87c41165 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -12,7 +12,15 @@ from typing import Any, cast import grpc -from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential +from tenacity import ( + Retrying, + after_log, + before_log, + retry, + stop_after_attempt, + wait_random_exponential, +) + from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -20,9 +28,13 @@ from modyn.storage.internal.grpc.generated.storage_pb2 import ( # pylint: disable=no-name-in-module GetRequest, GetResponse, + GetResponseNoLabels, ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub -from modyn.trainer_server.internal.dataset.key_sources import AbstractKeySource, SelectorKeySource +from modyn.trainer_server.internal.dataset.key_sources import ( + AbstractKeySource, + SelectorKeySource, +) from modyn.utils import ( BYTES_PARSER_FUNC_NAME, deserialize_function, @@ -53,7 +65,9 @@ def __init__( shuffle: bool, tokenizer: str | None, log_path: pathlib.Path | None, + generative: bool = False, ): + self._generative = generative self._pipeline_id = pipeline_id self._trigger_id = trigger_id self._training_id = training_id @@ -86,7 +100,9 @@ def __init__( self._pref_started: dict[int, bool] = {} self._thread_data_container: dict[int, dict[str, Any]] = {} self._partition_locks: dict[int, threading.Lock] = {} - self._partition_signals: dict[int, threading.Condition] = {} # Should use the lock out of partition_locks + self._partition_signals: dict[int, threading.Condition] = ( + {} + ) # Should use the lock out of partition_locks self._partition_valid_until: dict[int, int] = {} self._partition_valid: dict[int, bool] = {} self._next_partition_to_fetch = 0 @@ -96,13 +112,16 @@ def __init__( self._shuffled_partition_indices: list[int] = [] if log_path is None: - logger.warning("Did not provide log path for OnlineDataset - logging disabled.") + logger.warning( + "Did not provide log path for OnlineDataset - logging disabled." + ) # tokenizer for NLP tasks self._tokenizer = None self._tokenizer_name = tokenizer if tokenizer is not None: - self._tokenizer = instantiate_class("modyn.models.tokenizers", tokenizer) + self._tokenizer = instantiate_class( + "modyn.models.tokenizers", tokenizer) logger.debug("Initialized OnlineDataset.") @@ -124,7 +143,9 @@ def _setup_composed_transform(self) -> None: self._transform = transforms.Compose(self._transform_list) def _init_transforms(self) -> None: - self._bytes_parser_function = deserialize_function(self._bytes_parser, BYTES_PARSER_FUNC_NAME) + self._bytes_parser_function = deserialize_function( + self._bytes_parser, BYTES_PARSER_FUNC_NAME + ) self._transform = self._bytes_parser_function self._setup_composed_transform() @@ -136,7 +157,9 @@ def _init_transforms(self) -> None: reraise=True, ) def _init_grpc(self, worker_id: int | None = None) -> None: # pragma: no cover - self._storage_channel = grpc.insecure_channel(self._storage_address, options=grpc_common_config()) + self._storage_channel = grpc.insecure_channel( + self._storage_address, options=grpc_common_config() + ) if grpc_connection_established(self._storage_channel): self._storagestub = StorageStub(self._storage_channel) return @@ -147,29 +170,38 @@ def _init_grpc(self, worker_id: int | None = None) -> None: # pragma: no cover def _silence_pil(self) -> None: # pragma: no cover pil_logger = logging.getLogger("PIL") - pil_logger.setLevel(logging.INFO) # by default, PIL on DEBUG spams the console + # by default, PIL on DEBUG spams the console + pil_logger.setLevel(logging.INFO) def _info(self, msg: str, worker_id: int | None) -> None: # pragma: no cover - logger.info(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") + logger.info( + f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}" + ) def _debug(self, msg: str, worker_id: int | None) -> None: # pragma: no cover - logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") + logger.debug( + f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}" + ) def _get_data_from_storage( self, selector_keys: list[int], worker_id: int | None = None ) -> Iterator[tuple[list[int], list[bytes], list[int], int]]: processed_keys: set[int] | list[int] = [] has_failed = False - for attempt in Retrying( - stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + reraise=True, ): with attempt: try: - req = GetRequest(dataset_id=self._dataset_id, keys=selector_keys) + new_keys: list[int] + new_samples: list[bytes] + response:GetResponse + req = GetRequest( + dataset_id=self._dataset_id, keys=selector_keys + ) stopw = Stopwatch() - - response: GetResponse stopw.start("ResponseTime", overwrite=True) for response in self._storagestub.Get(req): response_time = stopw.stop("ResponseTime") @@ -177,32 +209,115 @@ def _get_data_from_storage( if not has_failed: assert isinstance(processed_keys, list) processed_keys.extend(keys) - yield keys, list(response.samples), list(response.labels), response_time + yield keys, list(response.samples), list( + response.labels + ), response_time else: # If we have failed, we need to filter out yielded samples # Note that the returned order by storage is non-deterministic assert isinstance(processed_keys, set) - new_keys: list[int] = [key for key in keys if key not in processed_keys] - new_samples: list[bytes] = [ - sample for key, sample in zip(keys, response.samples) if key not in processed_keys + new_keys = [ + key for key in keys if key not in processed_keys + ] + new_samples = [ + sample + for key, sample in zip(keys, response.samples) + if key not in processed_keys ] new_labels: list[int] = [ - label for key, label in zip(keys, response.labels) if key not in processed_keys + label + for key, label in zip(keys, response.labels) + if key not in processed_keys ] processed_keys.update(keys) yield new_keys, new_samples, new_labels, response_time stopw.start("ResponseTime", overwrite=True) + except ( + grpc.RpcError + ) as e: # We catch and reraise to reconnect to rpc and do logging + has_failed = True + # Convert processed keys to set on first failure + processed_keys = ( + set(processed_keys) + if isinstance(processed_keys, list) + else processed_keys + ) + self._info( + "gRPC error occurred, processed_keys = " + + f"{processed_keys}\n{e.code()} - {e.details()}", + worker_id, + ) + self._info(f"Stringified exception: {str(e)}", worker_id) + self._info( + f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", + worker_id, + ) + self._init_grpc(worker_id=worker_id) + raise e + + def _get_data_from_storage_gen( + self, selector_keys: list[int], worker_id: int | None = None + ) -> Iterator[tuple[list[int], list[bytes], int]]: + processed_keys: set[int] | list[int] = [] + has_failed = False + for attempt in Retrying( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + reraise=True, + ): + with attempt: + try: + new_keys: list[int] + new_samples: list[bytes] + response: GetResponseNoLabels + req = GetRequest( + dataset_id=self._dataset_id, keys=selector_keys + ) + stopw = Stopwatch() + stopw.start("ResponseTime", overwrite=True) + for response in self._storagestub.GetNL(req): + response_time = stopw.stop("ResponseTime") + keys = list(response.keys) + if not has_failed: + assert isinstance(processed_keys, list) + processed_keys.extend(keys) + yield keys, list(response.samples), response_time + else: # If we have failed, we need to filter out yielded samples + # Note that the returned order by storage is non-deterministic + assert isinstance(processed_keys, set) + new_keys = [ + key for key in keys if key not in processed_keys + ] + new_samples = [ + sample + for key, sample in zip(keys, response.samples) + if key not in processed_keys + ] - except grpc.RpcError as e: # We catch and reraise to reconnect to rpc and do logging + processed_keys.update(keys) + yield new_keys, new_samples, response_time + + stopw.start("ResponseTime", overwrite=True) + except ( + grpc.RpcError + ) as e: # We catch and reraise to reconnect to rpc and do logging has_failed = True # Convert processed keys to set on first failure - processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys + processed_keys = ( + set(processed_keys) + if isinstance(processed_keys, list) + else processed_keys + ) self._info( - "gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}", + "gRPC error occurred, processed_keys = " + + f"{processed_keys}\n{e.code()} - {e.details()}", worker_id, ) self._info(f"Stringified exception: {str(e)}", worker_id) - self._info(f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", worker_id) + self._info( + f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", + worker_id, + ) self._init_grpc(worker_id=worker_id) raise e @@ -223,37 +338,77 @@ def _get_data( get_data_log = {} self._sw.start(f"GetKeysAndWeightsPart{partition_id}", overwrite=True) keys, weights = self._key_source.get_keys_and_weights( - worker_id, shuffled_partition_id if shuffled_partition_id is not None else partition_id + worker_id, + ( + shuffled_partition_id + if shuffled_partition_id is not None + else partition_id + ), + ) + get_data_log["get_keys_and_weights"] = self._sw.stop( + f"GetKeysAndWeightsPart{partition_id}" ) - get_data_log["get_keys_and_weights"] = self._sw.stop(f"GetKeysAndWeightsPart{partition_id}") get_data_log["num_items"] = len(keys) self._info("Getting data from storage", worker_id) self._sw.start(f"GetDataPart{partition_id}", overwrite=True) all_response_times = [] + key_weight_map = ( + {key: weights[idx] for idx, key in enumerate(keys)} + if weights is not None + else None + ) + if self._generative: + for data_tuple_gen in self._get_data_from_storage_gen(keys, worker_id=worker_id): + stor_keys, data, response_time = data_tuple_gen + + all_response_times.append(response_time) + num_items = len(stor_keys) + with ( + partition_locks[partition_id] + if partition_locks is not None + else contextlib.suppress() + ): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["weights"].extend( + [cast(float | None, key_weight_map[key]) + for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - key_weight_map = {key: weights[idx] for idx, key in enumerate(keys)} if weights is not None else None - - for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): - stor_keys, data, labels, response_time = data_tuple - - all_response_times.append(response_time) - num_items = len(stor_keys) - with partition_locks[partition_id] if partition_locks is not None else contextlib.suppress(): - data_container["data"].extend(data) - data_container["keys"].extend(stor_keys) - data_container["labels"].extend(labels) - data_container["weights"].extend( - [cast(float | None, key_weight_map[key]) for key in stor_keys] - if key_weight_map is not None - else [None for _ in range(len(stor_keys))] - ) - if partition_valid_until is not None: - partition_valid_until[partition_id] += num_items + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() + else: + for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): + stor_keys, data, labels, response_time = data_tuple + + all_response_times.append(response_time) + num_items = len(stor_keys) + with ( + partition_locks[partition_id] + if partition_locks is not None + else contextlib.suppress() + ): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["labels"].extend(labels) + data_container["weights"].extend( + [cast(float | None, key_weight_map[key]) + for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - if partition_signals is not None: - with partition_signals[partition_id]: - partition_signals[partition_id].notify_all() + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() get_data_log["get_data"] = self._sw.stop(f"GetDataPart{partition_id}") get_data_log["response_times"] = all_response_times @@ -281,6 +436,18 @@ def _get_transformed_data_tuple( return key, transformed_sample, label, weight return key, transformed_sample, label + def _get_transformed_data_tuple_gen( + self, key: int, sample: memoryview, weight: float | None + ) -> tuple | None: + assert self._uses_weights is not None + self._sw.start("transform", resume=True) + # mypy complains here because _transform has unknown type, which is ok + transformed_sample = self._transform(sample) # type: ignore + self._sw.stop("transform") + if self._uses_weights: + return key, transformed_sample, weight + return key, transformed_sample + def end_of_trigger_cleaning(self) -> None: self._key_source.end_of_trigger_cleaning() @@ -297,14 +464,22 @@ def _persist_log(self, worker_id: int) -> None: log_file = f"{self._log_path / str(worker_id)}.log" self._log["transform"] = self._sw.measurements.get("transform", 0) - self._log["wait_for_later_partitions"] = self._sw.measurements.get("wait_for_later_partitions", 0) - self._log["wait_for_initial_partition"] = self._sw.measurements.get("wait_for_initial_partition", 0) + self._log["wait_for_later_partitions"] = self._sw.measurements.get( + "wait_for_later_partitions", 0 + ) + self._log["wait_for_initial_partition"] = self._sw.measurements.get( + "wait_for_initial_partition", 0 + ) with open(log_file, "w", encoding="utf-8") as logfile: json.dump(self._log, logfile) def _clear_partition(self, partition_id: int) -> None: - with self._partition_locks[partition_id] if self._partition_locks is not None else contextlib.suppress(): + with ( + self._partition_locks[partition_id] + if self._partition_locks is not None + else contextlib.suppress() + ): self._partition_valid[partition_id] = False self._partition_valid_until[partition_id] = -1 del self._thread_data_container[partition_id] @@ -312,35 +487,51 @@ def _clear_partition(self, partition_id: int) -> None: if "PYTEST_CURRENT_TEST" not in os.environ: gc.collect() - def _prefetch_partition(self, worker_id: int, maybe_continue: bool = False) -> None: + def _prefetch_partition( + self, worker_id: int, maybe_continue: bool = False + ) -> None: assert self._start_prefetch_lock is not None with self._start_prefetch_lock: - if self._num_prefetched_partitions < 1 or self._next_partition_to_fetch >= self._num_partitions: + if ( + self._num_prefetched_partitions < 1 + or self._next_partition_to_fetch >= self._num_partitions + ): return # Prefetching disabled or nothing more to prefetch - if maybe_continue and self._launched_prefetches >= self._num_prefetched_partitions: + if ( + maybe_continue + and self._launched_prefetches >= self._num_prefetched_partitions + ): return # Two callbacks started to prefetch basically at the same time if maybe_continue: # Do this as early as possible to avoid running into the "problem" above frequently self._launched_prefetches += 1 - assert self._next_partition_to_fetch >= 0 assert ( self._next_partition_to_fetch not in self._data_threads ), f"Prefetching for partition {self._next_partition_to_fetch} has already been started" - - self._thread_data_container[self._next_partition_to_fetch] = { - "data": [], - "keys": [], - "labels": [], - "weights": [], - } + if not self._generative: + self._thread_data_container[self._next_partition_to_fetch] = { + "data": [], + "keys": [], + "labels": [], + "weights": [], + } + else: + self._thread_data_container[self._next_partition_to_fetch] = { + "data": [], + "keys": [], + "weights": [], + } self._partition_valid[self._next_partition_to_fetch] = False self._partition_valid_until[self._next_partition_to_fetch] = -1 - self._partition_locks[self._next_partition_to_fetch] = threading.Lock() - self._partition_signals[self._next_partition_to_fetch] = threading.Condition( - self._partition_locks[self._next_partition_to_fetch] + self._partition_locks[self._next_partition_to_fetch] = threading.Lock( + ) + self._partition_signals[self._next_partition_to_fetch] = ( + threading.Condition( + self._partition_locks[self._next_partition_to_fetch] + ) ) callback = None @@ -367,7 +558,9 @@ def callback_func() -> None: # We implement shuffling on a partition level by mapping everything to increasing indices but actually load # different partition data. shuffle_partition_id = ( - self._shuffled_partition_indices[self._next_partition_to_fetch] if self._shuffle else None + self._shuffled_partition_indices[self._next_partition_to_fetch] + if self._shuffle + else None ) self._data_threads[self._next_partition_to_fetch] = threading.Thread( target=self._get_data, @@ -393,13 +586,37 @@ def _fetch_partition_noprefetch( self, worker_id: int, partition_id: int ) -> Iterator[tuple[int, memoryview, int, float | None]]: assert self._num_prefetched_partitions < 1 - container: dict[str, Any] = {"data": [], "keys": [], "labels": [], "weights": []} - shuffle_partition_id = self._shuffled_partition_indices[partition_id] if self._shuffle else None - self._get_data(container, worker_id, partition_id, None, None, None, None, None, shuffle_partition_id) - assert "data" in container and "labels" in container and "keys" in container and "weights" in container - + container: dict[str, Any] = { + "data": [], + "keys": [], + "labels": [], + "weights": [], + } + shuffle_partition_id = ( + self._shuffled_partition_indices[partition_id] if self._shuffle else None + ) + self._get_data( + container, + worker_id, + partition_id, + None, + None, + None, + None, + None, + shuffle_partition_id, + ) + assert ( + "data" in container + and "labels" in container + and "keys" in container + and "weights" in container + ) if self._shuffle: - self._shuffle_partition(partition_id, worker_id, container=container) + if not self._generative: + self._shuffle_partition(partition_id, worker_id,container=container) + else: + self._shuffle_partition_gen(partition_id, worker_id,container=container) for idx in range(len(container["keys"])): yield ( @@ -409,8 +626,45 @@ def _fetch_partition_noprefetch( container["weights"][idx], ) + def _fetch_partition_noprefetch_generative( # based on the one above + self, worker_id: int, partition_id: int + ) -> Iterator[tuple[int, memoryview, float | None]]: + assert self._num_prefetched_partitions < 1 + container: dict[str, Any] = {"data": [], "keys": [], "weights": []} + shuffle_partition_id = ( + self._shuffled_partition_indices[partition_id] if self._shuffle else None + ) + self._get_data( + container, + worker_id, + partition_id, + None, + None, + None, + None, + None, + shuffle_partition_id, + ) + assert "data" in container and "keys" in container and "weights" in container + + if self._shuffle: + if not self._generative: + self._shuffle_partition(partition_id, worker_id) + else: + self._shuffle_partition_gen(partition_id, worker_id) + + for idx in range(len(container["keys"])): + yield ( + container["keys"][idx], + memoryview(container["data"][idx]), + container["weights"][idx], + ) + def _is_partition_fetched(self, partition_id: int) -> bool: - if partition_id not in self._partition_locks or partition_id not in self._partition_valid: + if ( + partition_id not in self._partition_locks + or partition_id not in self._partition_valid + ): return False with self._partition_locks[partition_id]: @@ -422,24 +676,43 @@ def _partition_max_index(self, partition_id: int) -> int: def _get_partition_data( self, last_idx: int, max_idx: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: - for idx in range(last_idx + 1, max_idx + 1): - yield ( - self._thread_data_container[partition_id]["keys"][idx], - memoryview(self._thread_data_container[partition_id]["data"][idx]), - self._thread_data_container[partition_id]["labels"][idx], - self._thread_data_container[partition_id]["weights"][idx], - ) + ) -> Iterator[tuple[int, memoryview, int, float | None]]| Iterator[tuple[int, memoryview, float | None]]: + if self._generative: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview( + self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["weights"][idx], + ) + else: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview( + self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["labels"][idx], + self._thread_data_container[partition_id]["weights"][idx], + ) def _wait_for_new_partition_data(self, partition_id: int) -> None: with self._partition_signals[partition_id]: - self._partition_signals[partition_id].wait(1) # In case we do not get woken up, we at most waste a second + self._partition_signals[partition_id].wait( + 1 + ) # In case we do not get woken up, we at most waste a second - def _shuffle_partition(self, partition_id: int, worker_id: int, container: dict | None = None) -> None: - assert container is not None or self._is_partition_fetched(partition_id) + def _shuffle_partition( + self, partition_id: int, worker_id: int, container: dict | None = None + ) -> None: + assert container is not None or self._is_partition_fetched( + partition_id) assert self._shuffle - container = container if container is not None else self._thread_data_container[partition_id] + container = ( + container + if container is not None + else self._thread_data_container[partition_id] + ) self._info(f"Shuffling partition {partition_id}", worker_id) @@ -459,9 +732,40 @@ def _shuffle_partition(self, partition_id: int, worker_id: int, container: dict self._info(f"Shuffled partition {partition_id}", worker_id) + def _shuffle_partition_gen( + self, partition_id: int, worker_id: int, container: dict | None = None + ) -> None: + assert container is not None or self._is_partition_fetched( + partition_id) + assert self._shuffle + + container = ( + container + if container is not None + else self._thread_data_container[partition_id] + ) + + self._info(f"Shuffling partition {partition_id}", worker_id) + + data_length = len(container["data"]) + indices = list(range(data_length)) + random.shuffle(indices) + + new_data = [container["data"][i] for i in indices] + new_keys = [container["keys"][i] for i in indices] + + new_weights = [container["weights"][i] for i in indices] + + container["data"] = new_data + container["keys"] = new_keys + + container["weights"] = new_weights + + self._info(f"Shuffled partition {partition_id}", worker_id) + def prefetched_partition_generator( self, worker_id: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: + ) -> Iterator[tuple[int, memoryview, int, float | None]]|Iterator[tuple[int, memoryview, float | None]]: last_idx = -1 if not self._shuffle: # If we do not shuffle, we can emit data as soon as it streamed over @@ -470,14 +774,17 @@ def prefetched_partition_generator( max_idx = self._partition_max_index(partition_id) if max_idx <= last_idx: # No new data self._wait_for_new_partition_data(partition_id) - yield from self._get_partition_data(last_idx, max_idx, partition_id) last_idx = max_idx else: while not self._is_partition_fetched(partition_id): self._wait_for_new_partition_data(partition_id) + if not self._generative: + + self._shuffle_partition(partition_id, worker_id) + else: - self._shuffle_partition(partition_id, worker_id) + self._shuffle_partition_gen(partition_id, worker_id) # Yield potential remaining data (when not shuffling) or all data (when shuffling) self._info(f"Joining thread for partition {partition_id}", worker_id) @@ -504,21 +811,28 @@ def start_prefetching(self, worker_id: int) -> None: for _ in range(self._parallel_prefetch_requests): self._prefetch_partition(worker_id, True) - def all_partition_generator(self, worker_id: int) -> Iterator[tuple[int, memoryview, int, float | None]]: + def all_partition_generator( + self, worker_id: int + ) -> Iterator[tuple[int, memoryview, int, float | None]]|Iterator[tuple[int, memoryview,float | None]]: self.start_prefetching(worker_id) for partition_id in range(self._num_partitions): self._persist_log(worker_id) - if self._num_prefetched_partitions > 0: if partition_id < self._num_partitions - 1: # As we consume one partition, prefetch exactly one more partition self._prefetch_partition(worker_id, False) yield from self.prefetched_partition_generator(worker_id, partition_id) + else: - yield from self._fetch_partition_noprefetch(worker_id, partition_id) + if not self._generative: + yield from self._fetch_partition_noprefetch(worker_id, partition_id) + else: + yield from self._fetch_partition_noprefetch_generative( + worker_id, partition_id + ) # pylint: disable=too-many-locals, too-many-branches, too-many-statements def __iter__(self) -> Generator: @@ -531,7 +845,9 @@ def __iter__(self) -> Generator: if self._first_call: self._first_call = False - self._debug("This is the first run of iter, making gRPC connections.", worker_id) + self._debug( + "This is the first run of iter, making gRPC connections.", worker_id + ) # We have to initialize transformations and gRPC connections here to do it per dataloader worker, # otherwise the transformations/gRPC connections cannot be pickled for the new processes. self._init_transforms() @@ -560,9 +876,13 @@ def __iter__(self) -> Generator: self._num_partitions = self._key_source.get_num_data_partitions() if self._shuffle: - self._shuffled_partition_indices = list(range(0, self._num_partitions)) + self._shuffled_partition_indices = list( + range(0, self._num_partitions)) random.shuffle(self._shuffled_partition_indices) - self._info(f"Shuffled partitions into random order: {self._shuffled_partition_indices}", worker_id) + self._info( + f"Shuffled partitions into random order: {self._shuffled_partition_indices}", + worker_id, + ) self._info( f"Total number of partitions will be {self._num_partitions}.\n" @@ -573,10 +893,22 @@ def __iter__(self) -> Generator: assert self._log_lock is not None with self._log_lock: self._log["num_partitions"] = self._num_partitions - self._num_prefetched_partitions = min(self._num_prefetched_partitions, self._num_partitions) - - for data_tuple in self.all_partition_generator(worker_id): - if (transformed_tuple := self._get_transformed_data_tuple(*data_tuple)) is not None: - yield transformed_tuple + self._num_prefetched_partitions = min( + self._num_prefetched_partitions, self._num_partitions + ) + if self._generative: + for data_tuple in self.all_partition_generator(worker_id): + if ( + transformed_tuple := self._get_transformed_data_tuple_gen( + *data_tuple + ) + ) is not None: + yield transformed_tuple + else: + for data_tuple in self.all_partition_generator(worker_id): + if ( + transformed_tuple := self._get_transformed_data_tuple(*data_tuple) + ) is not None: + yield transformed_tuple self._persist_log(worker_id) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py index c3a21053b..b3fedcdb5 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py @@ -1,56 +1,53 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: trainer_server.proto -# Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" - +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\xbb\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' -) -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "trainer_server_pb2", _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_JSONSTRING"]._serialized_start = 33 - _globals["_JSONSTRING"]._serialized_end = 60 - _globals["_PYTHONSTRING"]._serialized_start = 62 - _globals["_PYTHONSTRING"]._serialized_end = 91 - _globals["_DATA"]._serialized_start = 93 - _globals["_DATA"]._serialized_end = 144 - _globals["_TRAINERAVAILABLEREQUEST"]._serialized_start = 146 - _globals["_TRAINERAVAILABLEREQUEST"]._serialized_end = 171 - _globals["_TRAINERAVAILABLERESPONSE"]._serialized_start = 173 - _globals["_TRAINERAVAILABLERESPONSE"]._serialized_end = 218 - _globals["_CHECKPOINTINFO"]._serialized_start = 220 - _globals["_CHECKPOINTINFO"]._serialized_end = 290 - _globals["_STARTTRAININGREQUEST"]._serialized_start = 293 - _globals["_STARTTRAININGREQUEST"]._serialized_end = 1248 - _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1250 - _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1320 - _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1322 - _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1366 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1369 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1791 - _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1793 - _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1838 - _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1840 - _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 1904 - _globals["_GETLATESTMODELREQUEST"]._serialized_start = 1906 - _globals["_GETLATESTMODELREQUEST"]._serialized_end = 1950 - _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 1952 - _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2017 - _globals["_TRAINERSERVER"]._serialized_start = 2020 - _globals["_TRAINERSERVER"]._serialized_end = 2477 + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14trainer_server.proto\x12\x07trainer\"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t\"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t\"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\"\x19\n\x17TrainerAvailableRequest\"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t\"\xcf\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12\"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x12\x12\n\ngenerative\x18\x1c \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer\"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05\",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen\"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05\",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse\"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse\"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse\"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse\"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'trainer_server_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _JSONSTRING._serialized_start=33 + _JSONSTRING._serialized_end=60 + _PYTHONSTRING._serialized_start=62 + _PYTHONSTRING._serialized_end=91 + _DATA._serialized_start=93 + _DATA._serialized_end=144 + _TRAINERAVAILABLEREQUEST._serialized_start=146 + _TRAINERAVAILABLEREQUEST._serialized_end=171 + _TRAINERAVAILABLERESPONSE._serialized_start=173 + _TRAINERAVAILABLERESPONSE._serialized_end=218 + _CHECKPOINTINFO._serialized_start=220 + _CHECKPOINTINFO._serialized_end=290 + _STARTTRAININGREQUEST._serialized_start=293 + _STARTTRAININGREQUEST._serialized_end=1268 + _STARTTRAININGRESPONSE._serialized_start=1270 + _STARTTRAININGRESPONSE._serialized_end=1340 + _TRAININGSTATUSREQUEST._serialized_start=1342 + _TRAININGSTATUSREQUEST._serialized_end=1386 + _TRAININGSTATUSRESPONSE._serialized_start=1389 + _TRAININGSTATUSRESPONSE._serialized_end=1811 + _STOREFINALMODELREQUEST._serialized_start=1813 + _STOREFINALMODELREQUEST._serialized_end=1858 + _STOREFINALMODELRESPONSE._serialized_start=1860 + _STOREFINALMODELRESPONSE._serialized_end=1924 + _GETLATESTMODELREQUEST._serialized_start=1926 + _GETLATESTMODELREQUEST._serialized_end=1970 + _GETLATESTMODELRESPONSE._serialized_start=1972 + _GETLATESTMODELRESPONSE._serialized_end=2037 + _TRAINERSERVER._serialized_start=2040 + _TRAINERSERVER._serialized_end=2497 # @@protoc_insertion_point(module_scope) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi index 96b20dde5..aeb9fda1a 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi @@ -56,9 +56,7 @@ class Data(google.protobuf.message.Message): dataset_id: builtins.str = ..., num_dataloaders: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["dataset_id", b"dataset_id", "num_dataloaders", b"num_dataloaders"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "num_dataloaders", b"num_dataloaders"]) -> None: ... global___Data = Data @@ -101,19 +99,13 @@ class CheckpointInfo(google.protobuf.message.Message): checkpoint_interval: builtins.int = ..., checkpoint_path: builtins.str = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "checkpoint_interval", b"checkpoint_interval", "checkpoint_path", b"checkpoint_path" - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["checkpoint_interval", b"checkpoint_interval", "checkpoint_path", b"checkpoint_path"]) -> None: ... global___CheckpointInfo = CheckpointInfo @typing.final class StartTrainingRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - PIPELINE_ID_FIELD_NUMBER: builtins.int TRIGGER_ID_FIELD_NUMBER: builtins.int DEVICE_FIELD_NUMBER: builtins.int @@ -141,6 +133,7 @@ class StartTrainingRequest(google.protobuf.message.Message): ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int DROP_LAST_BATCH_FIELD_NUMBER: builtins.int + GENERATIVE_FIELD_NUMBER: builtins.int pipeline_id: builtins.int trigger_id: builtins.int device: builtins.str @@ -158,6 +151,7 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool record_loss_every: builtins.int drop_last_batch: builtins.bool + generative: builtins.bool @property def torch_optimizers_configuration(self) -> global___JsonString: ... @property @@ -208,105 +202,14 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool = ..., record_loss_every: builtins.int = ..., drop_last_batch: builtins.bool = ..., + generative: builtins.bool = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_seed", - b"_seed", - "_tokenizer", - b"_tokenizer", - "bytes_parser", - b"bytes_parser", - "checkpoint_info", - b"checkpoint_info", - "criterion_parameters", - b"criterion_parameters", - "data_info", - b"data_info", - "grad_scaler_configuration", - b"grad_scaler_configuration", - "label_transformer", - b"label_transformer", - "lr_scheduler", - b"lr_scheduler", - "seed", - b"seed", - "tokenizer", - b"tokenizer", - "torch_optimizers_configuration", - b"torch_optimizers_configuration", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_seed", - b"_seed", - "_tokenizer", - b"_tokenizer", - "batch_size", - b"batch_size", - "bytes_parser", - b"bytes_parser", - "checkpoint_info", - b"checkpoint_info", - "criterion_parameters", - b"criterion_parameters", - "data_info", - b"data_info", - "device", - b"device", - "drop_last_batch", - b"drop_last_batch", - "enable_accurate_gpu_measurements", - b"enable_accurate_gpu_measurements", - "epochs_per_trigger", - b"epochs_per_trigger", - "grad_scaler_configuration", - b"grad_scaler_configuration", - "label_transformer", - b"label_transformer", - "load_optimizer_state", - b"load_optimizer_state", - "lr_scheduler", - b"lr_scheduler", - "num_prefetched_partitions", - b"num_prefetched_partitions", - "num_samples_to_pass", - b"num_samples_to_pass", - "parallel_prefetch_requests", - b"parallel_prefetch_requests", - "pipeline_id", - b"pipeline_id", - "pretrained_model_id", - b"pretrained_model_id", - "record_loss_every", - b"record_loss_every", - "seed", - b"seed", - "shuffle", - b"shuffle", - "tokenizer", - b"tokenizer", - "torch_criterion", - b"torch_criterion", - "torch_optimizers_configuration", - b"torch_optimizers_configuration", - "transform_list", - b"transform_list", - "trigger_id", - b"trigger_id", - "use_pretrained_model", - b"use_pretrained_model", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "lr_scheduler", b"lr_scheduler", "seed", b"seed", "tokenizer", b"tokenizer", "torch_optimizers_configuration", b"torch_optimizers_configuration"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "drop_last_batch", b"drop_last_batch", "enable_accurate_gpu_measurements", b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", "generative", b"generative", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "num_prefetched_partitions", b"num_prefetched_partitions", "num_samples_to_pass", b"num_samples_to_pass", "parallel_prefetch_requests", b"parallel_prefetch_requests", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "record_loss_every", b"record_loss_every", "seed", b"seed", "shuffle", b"shuffle", "tokenizer", b"tokenizer", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ... @typing.overload def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_tokenizer", b"_tokenizer"] - ) -> typing.Literal["tokenizer"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_tokenizer", b"_tokenizer"]) -> typing.Literal["tokenizer"] | None: ... global___StartTrainingRequest = StartTrainingRequest @@ -324,9 +227,7 @@ class StartTrainingResponse(google.protobuf.message.Message): training_started: builtins.bool = ..., training_id: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["training_id", b"training_id", "training_started", b"training_started"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["training_id", b"training_id", "training_started", b"training_started"]) -> None: ... global___StartTrainingResponse = StartTrainingResponse @@ -387,90 +288,18 @@ class TrainingStatusResponse(google.protobuf.message.Message): downsampling_batches_seen: builtins.int | None = ..., downsampling_samples_seen: builtins.int | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_batches_seen", - b"_batches_seen", - "_downsampling_batches_seen", - b"_downsampling_batches_seen", - "_downsampling_samples_seen", - b"_downsampling_samples_seen", - "_exception", - b"_exception", - "_samples_seen", - b"_samples_seen", - "batches_seen", - b"batches_seen", - "downsampling_batches_seen", - b"downsampling_batches_seen", - "downsampling_samples_seen", - b"downsampling_samples_seen", - "exception", - b"exception", - "log", - b"log", - "samples_seen", - b"samples_seen", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_batches_seen", - b"_batches_seen", - "_downsampling_batches_seen", - b"_downsampling_batches_seen", - "_downsampling_samples_seen", - b"_downsampling_samples_seen", - "_exception", - b"_exception", - "_samples_seen", - b"_samples_seen", - "batches_seen", - b"batches_seen", - "blocked", - b"blocked", - "downsampling_batches_seen", - b"downsampling_batches_seen", - "downsampling_samples_seen", - b"downsampling_samples_seen", - "exception", - b"exception", - "is_running", - b"is_running", - "is_training", - b"is_training", - "log", - b"log", - "samples_seen", - b"samples_seen", - "state_available", - b"state_available", - "valid", - b"valid", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_batches_seen", b"_batches_seen", "_downsampling_batches_seen", b"_downsampling_batches_seen", "_downsampling_samples_seen", b"_downsampling_samples_seen", "_exception", b"_exception", "_samples_seen", b"_samples_seen", "batches_seen", b"batches_seen", "downsampling_batches_seen", b"downsampling_batches_seen", "downsampling_samples_seen", b"downsampling_samples_seen", "exception", b"exception", "log", b"log", "samples_seen", b"samples_seen"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_batches_seen", b"_batches_seen", "_downsampling_batches_seen", b"_downsampling_batches_seen", "_downsampling_samples_seen", b"_downsampling_samples_seen", "_exception", b"_exception", "_samples_seen", b"_samples_seen", "batches_seen", b"batches_seen", "blocked", b"blocked", "downsampling_batches_seen", b"downsampling_batches_seen", "downsampling_samples_seen", b"downsampling_samples_seen", "exception", b"exception", "is_running", b"is_running", "is_training", b"is_training", "log", b"log", "samples_seen", b"samples_seen", "state_available", b"state_available", "valid", b"valid"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_batches_seen", b"_batches_seen"] - ) -> typing.Literal["batches_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_batches_seen", b"_batches_seen"]) -> typing.Literal["batches_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_downsampling_batches_seen", b"_downsampling_batches_seen"] - ) -> typing.Literal["downsampling_batches_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_downsampling_batches_seen", b"_downsampling_batches_seen"]) -> typing.Literal["downsampling_batches_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_downsampling_samples_seen", b"_downsampling_samples_seen"] - ) -> typing.Literal["downsampling_samples_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_downsampling_samples_seen", b"_downsampling_samples_seen"]) -> typing.Literal["downsampling_samples_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_exception", b"_exception"] - ) -> typing.Literal["exception"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_exception", b"_exception"]) -> typing.Literal["exception"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_samples_seen", b"_samples_seen"] - ) -> typing.Literal["samples_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_samples_seen", b"_samples_seen"]) -> typing.Literal["samples_seen"] | None: ... global___TrainingStatusResponse = TrainingStatusResponse @@ -503,9 +332,7 @@ class StoreFinalModelResponse(google.protobuf.message.Message): valid_state: builtins.bool = ..., model_id: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["model_id", b"model_id", "valid_state", b"valid_state"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["model_id", b"model_id", "valid_state", b"valid_state"]) -> None: ... global___StoreFinalModelResponse = StoreFinalModelResponse @@ -538,8 +365,6 @@ class GetLatestModelResponse(google.protobuf.message.Message): valid_state: builtins.bool = ..., model_path: builtins.str = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["model_path", b"model_path", "valid_state", b"valid_state"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["model_path", b"model_path", "valid_state", b"valid_state"]) -> None: ... global___GetLatestModelResponse = GetLatestModelResponse diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py index a1978b37a..b195f51fc 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py @@ -1,36 +1,8 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" - -import warnings - import grpc -import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 - -GRPC_GENERATED_VERSION = "1.63.0" -GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - warnings.warn( - f"The grpc package installed is at version {GRPC_VERSION}," - + f" but the generated code in trainer_server_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." - + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." - + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, - ) +import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 class TrainerServerStub(object): """Missing associated documentation comment in .proto file.""" @@ -42,35 +14,30 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.trainer_available = channel.unary_unary( - "/trainer.TrainerServer/trainer_available", - request_serializer=trainer__server__pb2.TrainerAvailableRequest.SerializeToString, - response_deserializer=trainer__server__pb2.TrainerAvailableResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/trainer_available', + request_serializer=trainer__server__pb2.TrainerAvailableRequest.SerializeToString, + response_deserializer=trainer__server__pb2.TrainerAvailableResponse.FromString, + ) self.start_training = channel.unary_unary( - "/trainer.TrainerServer/start_training", - request_serializer=trainer__server__pb2.StartTrainingRequest.SerializeToString, - response_deserializer=trainer__server__pb2.StartTrainingResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/start_training', + request_serializer=trainer__server__pb2.StartTrainingRequest.SerializeToString, + response_deserializer=trainer__server__pb2.StartTrainingResponse.FromString, + ) self.get_training_status = channel.unary_unary( - "/trainer.TrainerServer/get_training_status", - request_serializer=trainer__server__pb2.TrainingStatusRequest.SerializeToString, - response_deserializer=trainer__server__pb2.TrainingStatusResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/get_training_status', + request_serializer=trainer__server__pb2.TrainingStatusRequest.SerializeToString, + response_deserializer=trainer__server__pb2.TrainingStatusResponse.FromString, + ) self.store_final_model = channel.unary_unary( - "/trainer.TrainerServer/store_final_model", - request_serializer=trainer__server__pb2.StoreFinalModelRequest.SerializeToString, - response_deserializer=trainer__server__pb2.StoreFinalModelResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/store_final_model', + request_serializer=trainer__server__pb2.StoreFinalModelRequest.SerializeToString, + response_deserializer=trainer__server__pb2.StoreFinalModelResponse.FromString, + ) self.get_latest_model = channel.unary_unary( - "/trainer.TrainerServer/get_latest_model", - request_serializer=trainer__server__pb2.GetLatestModelRequest.SerializeToString, - response_deserializer=trainer__server__pb2.GetLatestModelResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/get_latest_model', + request_serializer=trainer__server__pb2.GetLatestModelRequest.SerializeToString, + response_deserializer=trainer__server__pb2.GetLatestModelResponse.FromString, + ) class TrainerServerServicer(object): @@ -79,216 +46,152 @@ class TrainerServerServicer(object): def trainer_available(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def start_training(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def get_training_status(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def store_final_model(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def get_latest_model(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_TrainerServerServicer_to_server(servicer, server): rpc_method_handlers = { - "trainer_available": grpc.unary_unary_rpc_method_handler( - servicer.trainer_available, - request_deserializer=trainer__server__pb2.TrainerAvailableRequest.FromString, - response_serializer=trainer__server__pb2.TrainerAvailableResponse.SerializeToString, - ), - "start_training": grpc.unary_unary_rpc_method_handler( - servicer.start_training, - request_deserializer=trainer__server__pb2.StartTrainingRequest.FromString, - response_serializer=trainer__server__pb2.StartTrainingResponse.SerializeToString, - ), - "get_training_status": grpc.unary_unary_rpc_method_handler( - servicer.get_training_status, - request_deserializer=trainer__server__pb2.TrainingStatusRequest.FromString, - response_serializer=trainer__server__pb2.TrainingStatusResponse.SerializeToString, - ), - "store_final_model": grpc.unary_unary_rpc_method_handler( - servicer.store_final_model, - request_deserializer=trainer__server__pb2.StoreFinalModelRequest.FromString, - response_serializer=trainer__server__pb2.StoreFinalModelResponse.SerializeToString, - ), - "get_latest_model": grpc.unary_unary_rpc_method_handler( - servicer.get_latest_model, - request_deserializer=trainer__server__pb2.GetLatestModelRequest.FromString, - response_serializer=trainer__server__pb2.GetLatestModelResponse.SerializeToString, - ), + 'trainer_available': grpc.unary_unary_rpc_method_handler( + servicer.trainer_available, + request_deserializer=trainer__server__pb2.TrainerAvailableRequest.FromString, + response_serializer=trainer__server__pb2.TrainerAvailableResponse.SerializeToString, + ), + 'start_training': grpc.unary_unary_rpc_method_handler( + servicer.start_training, + request_deserializer=trainer__server__pb2.StartTrainingRequest.FromString, + response_serializer=trainer__server__pb2.StartTrainingResponse.SerializeToString, + ), + 'get_training_status': grpc.unary_unary_rpc_method_handler( + servicer.get_training_status, + request_deserializer=trainer__server__pb2.TrainingStatusRequest.FromString, + response_serializer=trainer__server__pb2.TrainingStatusResponse.SerializeToString, + ), + 'store_final_model': grpc.unary_unary_rpc_method_handler( + servicer.store_final_model, + request_deserializer=trainer__server__pb2.StoreFinalModelRequest.FromString, + response_serializer=trainer__server__pb2.StoreFinalModelResponse.SerializeToString, + ), + 'get_latest_model': grpc.unary_unary_rpc_method_handler( + servicer.get_latest_model, + request_deserializer=trainer__server__pb2.GetLatestModelRequest.FromString, + response_serializer=trainer__server__pb2.GetLatestModelResponse.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler("trainer.TrainerServer", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + 'trainer.TrainerServer', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class TrainerServer(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def trainer_available( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def trainer_available(request, target, - "/trainer.TrainerServer/trainer_available", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/trainer_available', trainer__server__pb2.TrainerAvailableRequest.SerializeToString, trainer__server__pb2.TrainerAvailableResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def start_training( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def start_training(request, target, - "/trainer.TrainerServer/start_training", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/start_training', trainer__server__pb2.StartTrainingRequest.SerializeToString, trainer__server__pb2.StartTrainingResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_training_status( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def get_training_status(request, target, - "/trainer.TrainerServer/get_training_status", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/get_training_status', trainer__server__pb2.TrainingStatusRequest.SerializeToString, trainer__server__pb2.TrainingStatusResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def store_final_model( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def store_final_model(request, target, - "/trainer.TrainerServer/store_final_model", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/store_final_model', trainer__server__pb2.StoreFinalModelRequest.SerializeToString, trainer__server__pb2.StoreFinalModelResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_latest_model( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def get_latest_model(request, target, - "/trainer.TrainerServer/get_latest_model", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/get_latest_model', trainer__server__pb2.GetLatestModelRequest.SerializeToString, trainer__server__pb2.GetLatestModelResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py index 47acce439..3cc6ea2d7 100644 --- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py +++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py @@ -85,7 +85,8 @@ def __init__( self.pipeline_id = training_info.pipeline_id self.training_id = training_info.training_id self.trigger_id = training_info.trigger_id - + self._info("Initializing Pytorch Trainer") + self.generative=training_info.generative self.selector_stub = self.connect_to_selector(training_info.selector_address) if training_info.seed is not None: @@ -217,7 +218,6 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches self._info("Handled OnBegin Callbacks.") self._log["epochs"] = [] - training_loss: list[float] = [] if self.num_samples_to_pass == 0: epoch_num_generator: Iterable[int] = range(self.epochs_per_trigger) @@ -272,7 +272,8 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches passed_batches += 1 with GPUMeasurement(self._measure_gpu_ops, "PreprocessBatch", self._device, stopw, resume=True): sample_ids, target, data = self.preprocess_batch(batch, stopw) - + # if self.generative: + #target=data.clone() if retrieve_weights_from_dataloader: # model output is a torch.FloatTensor but weights is a torch.DoubleTensor. # We need to cast to do the dot product @@ -297,15 +298,32 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches self._assert_data_size(self._batch_size, data, sample_ids, target) with GPUMeasurement(self._measure_gpu_ops, "Forward", self._device, stopw, resume=True): - output = self._model.model(data, sample_ids) + if self.generative: + # Pass input data + output = self._model.model(data) + + else: + # Non-generative task: Pass data, and optionally sample_ids if required + output = self._model.model(data, sample_ids=sample_ids) + with GPUMeasurement(self._measure_gpu_ops, "Loss", self._device, stopw, resume=True): - if weighted_optimization: - # weighted gradient descent - assert weights is not None - loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum()) + if self.generative: + # Shift logits and labels for next-token prediction + output = output[..., :-1, :].contiguous() + target = data[..., 1:].contiguous() + # Calculate loss + loss = self._criterion( + output.view(-1, output.size(-1)), + output.view(-1) + ) else: - loss = self._criterion(output, target) + if weighted_optimization: + # Weighted gradient descent + assert weights is not None + loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum()) + else: + loss = self._criterion(output, target) stopw.start("OnBatchBeforeUpdate", resume=True) for _, callback in self._callbacks.items(): @@ -514,19 +532,21 @@ def preprocess_batch( sample_ids = sample_ids.tolist() elif isinstance(sample_ids, tuple): sample_ids = list(sample_ids) - + print(f"{self.generative}") assert isinstance(sample_ids, list), "Cannot parse result from DataLoader" stopw.stop("PreprocSampleIDs") - - stopw.start("LabelTransform", resume=True) - if self._label_transformer_function is not None: - target = self._label_transformer_function(batch[2]) + if self.generative: + target=None else: - target = batch[2] - stopw.stop("LabelTransform") + stopw.start("LabelTransform", resume=True) + if self._label_transformer_function is not None: + target = self._label_transformer_function(batch[2]) + else: + target = batch[2] + stopw.stop("LabelTransform") - with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True): - target = target.to(self._device) + with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True): + target = target.to(self._device) with GPUMeasurement(self._measure_gpu_ops, "MoveDataToGPU", self._device, stopw, resume=True): data: torch.Tensor | dict @@ -573,9 +593,7 @@ def downsample_batch( big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor() embeddings = self.get_embeddings_if_recorded() self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings) - self.end_embedding_recorder_if_needed() - # TODO(#218) Persist information on the sample IDs/weights when downsampling is performed selected_indexes, weights = self._downsampler.select_points() selected_data, selected_target = get_tensors_subset(selected_indexes, data, target, sample_ids) @@ -898,32 +916,36 @@ def _iterate_dataloader_and_compute_scores( Args: dataloader: torch.dataloader to get the data previous_batch_number: The batch number returned from the last call to this method. Useful when this - function is called several times to keep track of previous invocations (ex label by label dataloader). We - need to have a total to correctly update the queue and show the progress in the supervisor counter. - previous_number_of_samples: number of samples processed before calling this function. See above for the use. + function is called several times to keep track of previous invocations (e.g., label-by-label dataloader). + previous_number_of_samples: Number of samples processed before calling this function. See above for usage. Returns: Updated number of batches and samples """ number_of_samples = previous_number_of_samples batch_number = previous_batch_number + for batch in dataloader: self.update_queue("DOWNSAMPLING", batch_number, number_of_samples, training_active=False) batch_number += 1 sample_ids, target, data = self.preprocess_batch(batch) + # Handle cases where target is None for generative tasks + if self.generative and target is None: + target = torch.Tensor() number_of_samples += len(sample_ids) - no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode() context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr with context_manager: with torch.autocast(self._device_type, enabled=self._amp): - # compute the scores and accumulate them - model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor() + model_output = ( + self._model.model(data) if self._downsampler.forward_required else torch.Tensor() + ) embeddings = self.get_embeddings_if_recorded() + # Inform the downsampler self._downsampler.inform_samples(sample_ids, data, model_output, target, embeddings) - return batch_number, number_of_samples + # ---------------------------------------------------- Logging --------------------------------------------------- # def _info(self, msg: str) -> None: diff --git a/modyn/trainer_server/internal/utils/training_info.py b/modyn/trainer_server/internal/utils/training_info.py index 07a246b35..ac9df472a 100644 --- a/modyn/trainer_server/internal/utils/training_info.py +++ b/modyn/trainer_server/internal/utils/training_info.py @@ -57,7 +57,8 @@ def __init__( self.shuffle = request.shuffle self.enable_accurate_gpu_measurements = request.enable_accurate_gpu_measurements - + self.generative=request.generative + print(f"generetive:{self.generative}") assert ( self.pretrained_model_path or not self.use_pretrained_model ), "Inconsistent pretrained model configuration"