diff --git a/benchmark/mnist/mnist.yaml b/benchmark/mnist/mnist.yaml index ea0765945..3d065c22a 100644 --- a/benchmark/mnist/mnist.yaml +++ b/benchmark/mnist/mnist.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cuda:0" + generative: False dataloader_workers: 2 use_previous_model: True initial_model: random diff --git a/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml index a4c126594..ee6195f91 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml @@ -17,6 +17,7 @@ training: initial_model: random batch_size: 128 shuffle: True + generative: False optimizers: - name: "default" algorithm: "SGD" diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml index 7173bff24..48789b135 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 96 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml index 4ce54ab12..8e60f7d00 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml index e1d1f1aec..336ea8e5a 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml @@ -14,6 +14,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml index e376b47ab..a24598377 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml @@ -13,6 +13,8 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False + use_previous_model: True initial_model: random batch_size: 64 diff --git a/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml index 27eda1029..a93557042 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/huffpost.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cuda:0" + generative: False dataloader_workers: 2 use_previous_model: True initial_model: random diff --git a/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml 
b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml index cdb4d4c64..8f441c08e 100644 --- a/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml +++ b/benchmark/wildtime_benchmarks/example_pipelines/yearbook.yaml @@ -14,6 +14,7 @@ training: gpus: 1 device: "cuda:0" dataloader_workers: 2 + generative: False use_previous_model: True initial_model: random batch_size: 64 diff --git a/environment.yml b/environment.yml index ca9f223b0..47563adca 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - psycopg2 - sqlalchemy>=2.0 - pyaml - - pydantic + - pydantic==2.9.2 - numpy==1.26.* - pandas - bitstring @@ -43,11 +43,10 @@ dependencies: - nltk - pytorch::pytorch=2.2.1 - pytorch::torchvision - - pytorch::cpuonly # comment out if commenting in lines below for CUDA -# - pytorch::pytorch-cuda=12.1 -# - nvidia::cuda-libraries-dev=12.1.* -# - nvidia::cuda-nvcc=12.1.* -# - nvidia::cuda-nvtx=12.1.* -# - nvidia::cuda-cupti=12.1.* -# - nvidia::cuda-cudart-dev=12.1.* -# - nvidia::cuda-profiler-api=12.1.* + - pytorch::pytorch-cuda=12.1 + - nvidia::cuda-libraries-dev=12.1.* + - nvidia::cuda-nvcc=12.1.* + - nvidia::cuda-nvtx=12.1.* + - nvidia::cuda-cupti=12.1.* + - nvidia::cuda-cudart-dev=12.1.* + - nvidia::cuda-profiler-api==12.1.* diff --git a/integrationtests/config/dummy.yaml b/integrationtests/config/dummy.yaml index f37fe9cae..b9fdd5b6c 100644 --- a/integrationtests/config/dummy.yaml +++ b/integrationtests/config/dummy.yaml @@ -12,6 +12,7 @@ model_storage: training: gpus: 1 device: "cpu" + generative: False dataloader_workers: 1 use_previous_model: True initial_model: random diff --git a/integrationtests/config/rho_loss.yaml b/integrationtests/config/rho_loss.yaml index a30748c44..f2f89e76d 100644 --- a/integrationtests/config/rho_loss.yaml +++ b/integrationtests/config/rho_loss.yaml @@ -13,6 +13,7 @@ training: gpus: 1 device: "cpu" dataloader_workers: 2 + generative: False use_previous_model: False initial_model: random batch_size: 4 @@ -60,6 +61,7 @@ selection_strategy: il_model_config: num_classes: 10 device: "cpu" + generative: False dataloader_workers: 1 use_previous_model: False batch_size: 2 @@ -75,4 +77,4 @@ selection_strategy: lr: 0.1 momentum: 0.001 optimization_criterion: - name: "CrossEntropyLoss" + name: "CrossEntropyLoss" \ No newline at end of file diff --git a/modyn/common/grpc/grpc_helpers.py b/modyn/common/grpc/grpc_helpers.py index f115d0c22..f8a3f2a24 100644 --- a/modyn/common/grpc/grpc_helpers.py +++ b/modyn/common/grpc/grpc_helpers.py @@ -251,6 +251,7 @@ def prepare_start_training_request( enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements, record_loss_every=training_config.record_loss_every, drop_last_batch=training_config.drop_last_batch, + generative=training_config.generative, ) def start_training( diff --git a/modyn/config/examples/modyn_config.yaml b/modyn/config/examples/modyn_config.yaml index 42ec1588e..6ed900978 100644 --- a/modyn/config/examples/modyn_config.yaml +++ b/modyn/config/examples/modyn_config.yaml @@ -278,7 +278,7 @@ selector: local_storage_directory: "/tmp/local_storage" local_storage_max_samples_in_file: 1000000 cleanup_storage_directories_after_shutdown: true - ignore_existing_trigger_samples: false + ignore_existing_trigger_samples: true trainer_server: hostname: "trainer_server" diff --git a/modyn/config/schema/pipeline/training/config.py b/modyn/config/schema/pipeline/training/config.py index b3f673553..98859587d 100644 --- a/modyn/config/schema/pipeline/training/config.py 
+++ b/modyn/config/schema/pipeline/training/config.py @@ -119,6 +119,11 @@ class TrainingConfig(ModynBaseModel): "we start with random weights. If initial_model is 'pretrained', cannot be False." ) ) + generative: bool = Field(False, + description=( + "If True, the training pipeline goes into the generative branch and data is sampled without expecting labels." + ) + ) seed: int | None = Field( None, description=( diff --git a/modyn/config/schema/system/config.py index 881f0fb01..4827b6e9b 100644 --- a/modyn/config/schema/system/config.py +++ b/modyn/config/schema/system/config.py @@ -255,7 +255,7 @@ class SelectorConfig(HostnamePortMixin): ), ) ignore_existing_trigger_samples: bool = Field( - False, + True, description=( "Whether to ignore existing trigger samples when starting the selector. If set to false, the trigger " "sample directory has to be empty upon startup. May lead to unexpected behaviour if set to true and the " diff --git a/modyn/models/GPT2/GPT2.py b/modyn/models/GPT2/GPT2.py new file mode 100644 index 000000000..2ebf5cd49 --- /dev/null +++ b/modyn/models/GPT2/GPT2.py @@ -0,0 +1,365 @@ +# import pytorch_lightning as pl + +from transformers import ( + Adafactor, + GPT2LMHeadModel, + GPT2Tokenizer, +) + +import torch +#from Datasets import CustomDataset, Pretrain_Chunks +from torch.utils.data import RandomSampler +from torch.utils.data import DataLoader, ConcatDataset +from collections import Counter +from typing import Optional, Tuple, Dict, Union, Any, Callable +from typing import Any, Union, List +import re +import string +#from deepspeed.runtime.lr_schedules import WarmupDecayLR +#import deepspeed +import math +import os +import csv +from modyn.models.GPT2.GPT2_Model_LoRA import GPT2LMHeadModel #as GPT2_Lora +#from modyn.models.GPT2.RecAdam import RecAdam +from modyn.models.coreset_methods_support import CoresetSupportingModule + + +class GPT2: + # pylint: disable-next=unused-argument + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = GPT2Modyn(hparams) + self.model.to(device) + + +# the following class is adapted from +# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py + +class GPT2Modyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super(GPT2Modyn, self).__init__() + # self.save_hyperparameters(hparams) + self.unchanged_loss: float = 0.0 + self.changed_loss: float = 0.0 + self.invariant_loss: float = 0.0 + self.unchanged: int = 0 + self.changed: int = 0 + self.invariant: int = 0 + self.validation: int = 0 + self.validation_loss: float = 0.0 + + self.mix_ratio: float = 1.0 + self.mix_decay: float = 0.7 + self.epoch: int = 0 + + self.model = GPT2LMHeadModel.from_pretrained("gpt2-large") # hparams.model_name_or_path + + def freeze_params(self, model: torch.nn.Module) -> None: + for par in model.parameters(): + par.requires_grad = False + + def normalize_answer(self, s: str) -> str: + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text: str) -> str: + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text: str) -> str: + return " ".join(text.split()) + + def remove_punc(text: str) -> str: + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text: str) -> str: + return text.lower() + + def rid_of_specials(text: str) -> str: + text = text.replace("", "") + text = text.replace("", "") + return text + + return
rid_of_specials(white_space_fix(remove_articles(remove_punc(lower(s))))) + + def exact_match_score(self, prediction: str, ground_truth: str) -> int: + return int(self.normalize_answer(prediction) == self.normalize_answer(ground_truth)) + + def _f1_score(self, prediction: str, ground_truth: str) -> float: + prediction_tokens = self.normalize_answer(prediction).split() + ground_truth_tokens = self.normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0.0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + return (2 * precision * recall) / (precision + recall) + + def calculate_scores(self, predictions: list[str], ground_truths: list[str]) -> tuple[float, float]: + em_score: float = 0.0 + f1_score: float = 0.0 + + for i in range(len(predictions)): + ground_truth = ground_truths[i] + prediction = predictions[i] + em_score += self.exact_match_score(prediction, ground_truth) + f1_score += self._f1_score(prediction, ground_truth) + + em_score /= len(predictions) + f1_score /= len(predictions) + return em_score * 100, f1_score * 100 + + def lmap(self, f: Callable, x: list[Any]) -> list[Any]: + """list(map(f, x))""" + return list(map(f, x)) + + def is_logger(self) -> bool: + return self.trainer.global_rank <= 0 + + def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + lm_labels: Optional[torch.Tensor] = None) -> Any: + return self.model( + input_ids, + attention_mask=attention_mask, + labels=lm_labels, + ) + + def get_last_layer(self) -> torch.nn.Module: + return self.model.lm_head + + """ + def _step(self, batch): + lm_labels = batch["target_ids"] + lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100 + outputs = self( + input_ids=batch["source_ids"], + attention_mask=batch["source_mask"], + lm_labels=lm_labels, + ) + + loss = outputs[0] + return loss + + def valid_step(self, batch): + lm_labels = batch["label_ids"].clone().detach() + lm_labels[source_nonprompt_mask == 0] = -100 + outputs = self( + input_ids=batch["label_ids"], + attention_mask=batch["label_mask"], + lm_labels=lm_labels, + ) + + loss = outputs[0] + print(loss) + return loss + + + + def ids_to_clean_text(self, generated_ids): + gen_text = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return self.lmap(str.strip, gen_text) + + def _generative_step_finetune(self, batch, batch_idx): + loss = self.valid_step(batch) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.validation +=1 + self.validation_loss += loss + average_loss = self.validation_loss / self.validation + ppl = torch.exp(average_loss) + self.log('validation_ppl', ppl, prog_bar=True, logger=True) + + source = self.ids_to_clean_text(batch["source_ids"]) + generated_ids = self.model.generate( + batch["source_ids"], + attention_mask=batch["source_mask"], + use_cache=True, + max_length=self.hparams.max_input_length + 5, + num_beams=2, + early_stopping=True + ) + targets = self.ids_to_clean_text(batch["target_ids"]) + + generated_ids = torch.transpose(torch.transpose(generated_ids,0,1)[self.hparams.max_input_length:],0,1) + preds = self.ids_to_clean_text(generated_ids) + clean_preds = [] + for text in preds: + if "." 
in text: + clean_preds.append(text[:text.find(".")+1]) + else: + clean_preds.append(text) + print("clean_preds",clean_preds) + print("targets",targets) + + em_score, f1_score = self.calculate_scores(clean_preds, targets) + print(em_score, f1_score, ppl) + self.log('EM score', em_score, prog_bar=True, logger=True) + self.log('F1 score', f1_score, prog_bar=True, logger=True) + + + def _generative_step(self, batch, batch_idx, dataloader_idx=-1): + loss = self._step(batch) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + if dataloader_idx == 0: + self.unchanged +=1 + self.unchanged_loss += loss + average_loss = self.unchanged_loss / self.unchanged + ppl = torch.exp(average_loss) + self.log('UnC_ppl', ppl, prog_bar=True, logger=True) + print('UnC_ppl', ppl) + elif dataloader_idx == 1: + self.changed +=1 + self.changed_loss += loss + average_loss = self.changed_loss / self.changed + ppl = torch.exp(average_loss) + self.log('C_ppl', ppl, prog_bar=True, logger=True) + print('C_ppl', ppl) + else: + self.invariant +=1 + self.invariant_loss += loss + average_loss = self.invariant_loss / self.invariant + ppl = torch.exp(average_loss) + self.log('IL_ppl', ppl, prog_bar=True, logger=True) + print('IL_ppl', ppl) + + def training_step(self, batch, batch_idx): + loss = self._step(batch) + self.log("loss", loss) + return loss + + def on_train_epoch_end(self): + if self.hparams.mode=='pretrain_brute': + self.dataset_index+=1 + if self.dataset_index==self.hparams.num_files: + self.global_epoch+=1 + self.log('global_epoch', self.global_epoch, prog_bar=True, logger=True) + self.dataset_index=0 + self.train_dataloader() + if self.hparams.method=='mixreview': + train_set = self.train_dataloader().dataset + self.epoch+=1 + + def validation_step(self, batch, batch_idx, dataloader_idx=-1): + if self.hparams.mode == 'finetune': + return self._generative_step_finetune(batch, batch_idx, dataloader_idx) + return self._generative_step(batch, batch_idx, dataloader_idx) + + def configure_optimizers(self, train_len=None): + "Prepare optimizer and schedule (linear warmup and decay)" + if self.hparams.method=='recadam': + no_decay = ["bias", "LayerNorm.weight"] + model_type = 'gpt2' + recadam_anneal_w = 1.0 + recadam_anneal_fun = 'sigmoid' + recadam_anneal_k = 0.5 + recadam_anneal_t0 = 250 + recadam_pretrain_cof = 5000.0 + new_model = self.model + pretrained_model = self.pretrained_model + optimizer_grouped_parameters = [ + { + "params": [p for n, p in new_model.named_parameters() if + not any(nd in n for nd in no_decay) and model_type in n], + "weight_decay": self.hparams.weight_decay, + "anneal_w": recadam_anneal_w, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + not any(nd in p_n for nd in no_decay) and model_type in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + not any(nd in n for nd in no_decay) and model_type not in n], + "weight_decay": self.hparams.weight_decay, + "anneal_w": 0.0, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + not any(nd in p_n for nd in no_decay) and model_type not in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + any(nd in n for nd in no_decay) and model_type in n], + "weight_decay": 0.0, + "anneal_w": recadam_anneal_w, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + any(nd in p_n for nd in no_decay) and model_type in p_n] + }, + { + "params": [p for n, p in new_model.named_parameters() if + any(nd in 
n for nd in no_decay) and model_type not in n], + "weight_decay": 0.0, + "anneal_w": 0.0, + "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if + any(nd in p_n for nd in no_decay) and model_type not in p_n] + } + ] + optimizer = RecAdam(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon, + anneal_fun=recadam_anneal_fun, anneal_k=recadam_anneal_k, + anneal_t0=recadam_anneal_t0, pretrain_cof=recadam_pretrain_cof) + else: + model = self.model + + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.hparams.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + if self.hparams.accelerator is not None: + optimizer = deepspeed.ops.adam.FusedAdam(optimizer_grouped_parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) + else: + optimizer = Adafactor(optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False) + + if self.hparams.use_lr_scheduling: + if self.hparams.len_data==None: + len_data = len(self.train_dataloader()) + else: + len_data = int(self.hparams.len_data // self.hparams.train_batch_size) + denomniator = (self.hparams.n_gpu * self.hparams.gradient_accumulation_steps) + + steps_per_epoch = ( len_data // denomniator ) + 1 + schedule_scale_factor = 8 + total_num_steps = ( steps_per_epoch * self.hparams.num_train_epochs ) * self.hparams.num_files * schedule_scale_factor + + print(f'total number of steps : {total_num_steps}') + scheduler = WarmupDecayLR(optimizer, total_num_steps = total_num_steps ,warmup_max_lr = self.hparams.learning_rate, warmup_num_steps = int(total_num_steps * 0.1)) + return [optimizer], [{"scheduler": scheduler, "interval": "step", "name": "learning rate"}] + else: + return [optimizer] + + def train_dataloader(self): + if self.hparams.mode=='pretrain_brute': + train_dataset = Pretrain_Chunks(dataset_name=self.dataset_lst[self.dataset_index],tokenizer=self.tokenizer, input_length=self.hparams.max_input_length, output_length=self.hparams.max_output_length, args=self.hparams) + else: + train_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.hparams) + if self.hparams.method=='mixreview': + #mix_len = int(len(train_dataset) * self.mix_ratio * (self.mix_decay ** self.epoch)) + mix_len = int(len(train_dataset)) + pretrain_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="pretrain", args=self.hparams, length=mix_len) + mixed_dataset = ConcatDataset([train_dataset,pretrain_dataset]) + print("mix len is ", mix_len) + sampler = RandomSampler(mixed_dataset) + dataloader = DataLoader(mixed_dataset, sampler = sampler, batch_size=self.hparams.train_batch_size, drop_last=True, num_workers=self.hparams.num_workers) + print("dataset length is ", len(dataloader.dataset)) + else: + sampler = RandomSampler(train_dataset) + dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=self.hparams.train_batch_size, drop_last=True, num_workers=self.hparams.num_workers) + return dataloader + + def val_dataloader(self): + validation_dataset_unchanged = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams, lama_type='unchanged') + validation_dataset_changed = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", 
args=self.hparams, lama_type='changed') + return [DataLoader(validation_dataset_unchanged, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False), + DataLoader(validation_dataset_changed, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False), + ] + def test_dataloader(self): + test_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="test", args=self.hparams) + + return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False) + """ \ No newline at end of file diff --git a/modyn/models/GPT2/GPT2_Model_LoRA.py b/modyn/models/GPT2/GPT2_Model_LoRA.py new file mode 100644 index 000000000..d474d1474 --- /dev/null +++ b/modyn/models/GPT2/GPT2_Model_LoRA.py @@ -0,0 +1,1115 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Union, Any, Callable +from typing import Any, Union,List +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + + +if version.parse(torch.__version__) >= version.parse("1.6"): + is_amp_available = True + from torch.cuda.amp import autocast +else: + is_amp_available = False + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + Conv1D, + PreTrainedModel, + SequenceSummary, + find_pruneable_heads_and_indices, + prune_conv1d_layer, +) +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def load_tf_weights_in_gpt2( + model: torch.nn.Module, + config: Any, + gpt2_checkpoint_path: str +) -> torch.nn.Module: + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + +class LoRALayer(nn.Module): + def __init__( + self, + n_in: int, + n_out: Optional[int] = None, + adapter_dim: int = 16, + adapter_alpha: int = 32 + ) -> None: + super(LoRALayer, self).__init__() + if not n_out: + n_out = n_in + self.adapter_dim = adapter_dim + self.adapter_alpha = adapter_alpha + self.adapter_proj_1 = nn.Linear(n_in, adapter_dim, bias=False) + nn.init.normal_(self.adapter_proj_1.weight, std=0.02) + self.adapter_proj_2 = nn.Linear(adapter_dim, n_out, bias=False) + self.adapter_proj_2.weight.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + scale_factor = self.adapter_dim / self.adapter_alpha + result = torch.matmul(x, self.adapter_proj_1.weight.type_as(x).T) + return torch.matmul(result, self.adapter_proj_2.weight.type_as(x).T) * scale_factor + +class GPT2Attention(nn.Module): + def __init__( + self, + config: Any, # Replace `Any` with the specific config class if known + is_cross_attention: bool = False, + layer_idx: int = -1 +) -> None: + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.lora_attn_dim = 4 + self.lora_attn_alpha = 16 + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 
+ ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.q_lora = LoRALayer(self.embed_dim, adapter_dim=self.lora_attn_dim, adapter_alpha=self.lora_attn_alpha) + self.v_lora = LoRALayer(self.embed_dim, adapter_dim=self.lora_attn_dim, adapter_alpha=self.lora_attn_alpha) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads: set[int] = set() + + + def prune_heads(self, heads: List[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None + ) -> Any: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None + ) -> Any: + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, 
device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + if is_amp_available: + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + else: + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads( + self, + tensor: torch.Tensor, + num_heads: int, + attn_head_size: int + ) -> torch.Tensor: + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(*new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads( + self, + tensor: torch.Tensor, + num_heads: int, + attn_head_size: int + ) -> torch.Tensor: + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False + )-> Any: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_delta = self.q_lora(hidden_states) + value_delta = self.v_lora(encoder_hidden_states) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_delta = self.q_lora(hidden_states) + value_delta = self.v_lora(hidden_states) + + query = query.contiguous() + query_delta + value = value.contiguous() + value_delta + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + """ + if use_cache is True: + present = (key, value) + else: + present = None + """ + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output,None )# type: ignore + if output_attentions: + outputs += (attn_weights,) # type: ignore + + return outputs # a, present, (attentions) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size: int, config: Any) -> None: + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + +class GPT2Block(nn.Module): + def __init__(self, config: Any, layer_idx: int = -1) -> None: + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn: nn.Module = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention: nn.Module = GPT2Attention(config, is_cross_attention=True) + self.ln_cross_attn: nn.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp: nn.Module = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + 
output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + + def __init__(self, *inputs: Any, **kwargs: Any) -> None: + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + def _set_gradient_checkpointing( + self, + module: Union[torch.nn.Module, Any], # Replace `Any` with the actual type of GPT2Model if available + value: bool = False + ) -> None: + """Enable or disable gradient checkpointing for the given module.""" + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + + +GPT2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be + passed as `input_ids`. + Indices can be obtained using [`GPT2Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which + have their past given to this model should not be passed as `input_ids` as they have already been + computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up + decoding (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + Example: + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained('gpt2-xl') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. 
+ Example: + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained('gpt2-large') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) + +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing: List[str] = ["attn.masked_bias"] + + def __init__(self, config: Any) -> None: + super().__init__(config) + + self.embed_dim: int = config.hidden_size + + self.wte: nn.Embedding = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe: nn.Embedding = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop: nn.Dropout = nn.Dropout(config.embd_pdrop) + self.h: nn.ModuleList = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f: nn.LayerNorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel: bool = False + self.device_map: Optional[Dict[int, List[int]]] = None + self.gradient_checkpointing: bool = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map: Optional[Dict[int, List[int]]] = None) -> None: + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device: str = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device: str = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self) -> None: + self.model_parallel = False + self.device_map = None + self.first_device: str = "cpu"# type: ignore[no-redef] + self.last_device: str = "cpu"# type: ignore[no-redef] + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self) -> nn.Embedding: + return self.wte + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + Tuple[ + torch.Tensor, + Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]], + Optional[Tuple[torch.Tensor, ...]], + Optional[Tuple[torch.Tensor, ...]], + Optional[Tuple[torch.Tensor, ...]], + ], + BaseModelOutputWithPastAndCrossAttentions, + ]: + # The implementation remains unchanged + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device # type: ignore[union-attr] + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) # type: ignore[list-item] + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
+ attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents: Tuple[torch.Tensor, ...] = () + all_self_attentions: Tuple[torch.Tensor, ...] = () + all_cross_attentions: Tuple[torch.Tensor, ...] = () + all_hidden_states: Tuple[torch.Tensor, ...] = () + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module: Callable[..., Any]) -> Callable[..., Any]: + def custom_forward(*inputs: Any) -> Any: + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): # type: ignore[union-attr] + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + GPT2_START_DOCSTRING, +) + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing: List[str] = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config: Any) -> None: + super().__init__(config) + self.transformer: GPT2Model = GPT2Model(config) + self.lm_head: nn.Linear = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel: bool = False + self.device_map: Optional[Dict[int, List[int]]] = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map: Optional[Dict[int, List[int]]] = None) -> None: + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self) -> None: + self.transformer.deparallelize() + self.transformer = self.transformer.to("cuda") + self.lm_head = self.lm_head.to("cuda") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self) -> nn.Linear: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + token_type_ids = kwargs.get("token_type_ids", None) + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + Tuple[torch.Tensor, ...], + CausalLMOutputWithCrossAttentions, + ]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = 
self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, ...], ...], + beam_idx: torch.Tensor, + ) -> Tuple[Tuple[torch.Tensor, ...], ...]: + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + diff --git a/modyn/models/GPT2/_init_.py b/modyn/models/GPT2/_init_.py new file mode 100644 index 000000000..f5987b52c --- /dev/null +++ b/modyn/models/GPT2/_init_.py @@ -0,0 +1,5 @@ +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/__init__.py b/modyn/models/__init__.py index 8d92f68e8..668e9ee6c 100644 --- a/modyn/models/__init__.py +++ b/modyn/models/__init__.py @@ -12,6 +12,7 @@ from .rho_loss_twin_model.rho_loss_twin_model import RHOLOSSTwinModel # noqa: F401 from .smallyearbooknet.smallyearbooknet import SmallYearbookNet # noqa: F401 from .yearbooknet.yearbooknet import YearbookNet # noqa: F401 +from .GPT2.GPT2 import GPT2 #noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/models/tokenizers/__init__.py b/modyn/models/tokenizers/__init__.py index 93e8d9c8f..47884e93e 100644 --- a/modyn/models/tokenizers/__init__.py +++ b/modyn/models/tokenizers/__init__.py @@ -3,7 +3,7 @@ import os from .distill_bert_tokenizer import DistilBertTokenizerTransform # noqa: F401 - +from .gpt2_tokenizer import GPT2TokenizerTransform # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") __all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/tokenizers/gpt2_tokenizer.py b/modyn/models/tokenizers/gpt2_tokenizer.py new file mode 100644 index 000000000..204abce03 --- /dev/null +++ b/modyn/models/tokenizers/gpt2_tokenizer.py @@ -0,0 +1,28 @@ +import torch +from transformers import GPT2Tokenizer + +class GPT2TokenizerTransform: + def __init__(self, max_token_length: int = 300): + # Load the GPT-2 tokenizer + self.max_token_length = max_token_length + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") + # Set the pad token to the eos token 
to avoid padding errors + self.tokenizer.add_special_tokens({ + "eos_token": "", + "bos_token": "", + "unk_token": "", + "pad_token": "", + "mask_token": "" + }) + self.tokenizer.padding_side = "left" + + def __call__(self, sample: str) -> torch.Tensor: + # Make the class callable to use it as Torch Transform + tokens = self.tokenizer( + sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" + ) + # Create a tensor whose first dimension is the input_ids and the second is the attention_mask + data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + data = torch.squeeze(data, dim=0) # First shape dim is always 1, since the input is just one string + return data + \ No newline at end of file diff --git a/modyn/protos/storage.proto b/modyn/protos/storage.proto index 4b6fc9dde..37fa4bff8 100644 --- a/modyn/protos/storage.proto +++ b/modyn/protos/storage.proto @@ -4,6 +4,7 @@ package modyn.storage; service Storage { rpc Get(GetRequest) returns (stream GetResponse) {} + rpc GetNL(GetRequest) returns (stream GetResponseNoLabels) {} rpc GetNewDataSince(GetNewDataSinceRequest) returns (stream GetNewDataSinceResponse) {} rpc GetDataInInterval(GetDataInIntervalRequest) @@ -25,6 +26,7 @@ service Storage { message GetRequest { string dataset_id = 1; repeated int64 keys = 2; + bool include_labels = 3; // Added this line } message GetResponse { @@ -33,6 +35,13 @@ message GetResponse { repeated int64 labels = 3; } + +message GetResponseNoLabels { + repeated bytes samples = 1; + repeated int64 keys = 2; + +} + // https://github.com/grpc/grpc/issues/15937 message GetCurrentTimestampRequest {} diff --git a/modyn/protos/trainer_server.proto b/modyn/protos/trainer_server.proto index 0c7343c70..6eda308c4 100644 --- a/modyn/protos/trainer_server.proto +++ b/modyn/protos/trainer_server.proto @@ -59,6 +59,7 @@ message StartTrainingRequest { bool enable_accurate_gpu_measurements = 25; int64 record_loss_every = 26; bool drop_last_batch = 27; + bool generative=28; } message StartTrainingResponse { diff --git a/modyn/storage/include/internal/grpc/draft.hpp b/modyn/storage/include/internal/grpc/draft.hpp new file mode 100644 index 000000000..f233dd948 --- /dev/null +++ b/modyn/storage/include/internal/grpc/draft.hpp @@ -0,0 +1,152 @@ +//"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +// Adapted function to send sample data without labels +template > +static void send_sample_data_for_keys_and_file_NL( // NOLINT(readability-function-cognitive-complexity) //TODO Adaptar esto + WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + + // Note that we currently ignore the sample batch size here, under the assumption that users do not request more + // keys than this + try { + const uint64_t num_keys = sample_keys.size(); + + if (num_keys == 0) { + SPDLOG_ERROR("num_keys is 0, this should not have happened. 
Exiting send_sample_data_for_keys_and_file_NL"); + return; + } + + // Removed labels-related vectors + // std::vector sample_labels(num_keys); + std::vector sample_indices(num_keys); + std::vector sample_fileids(num_keys); + + const std::string sample_query = fmt::format( + "SELECT sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER BY file_id", + fmt::join(sample_keys, ",")); + session << sample_query, + soci::into(sample_indices), + soci::into(sample_fileids), + soci::use(dataset_data.dataset_id); + + if (sample_fileids.size() != num_keys) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException( + fmt::format("Got back {} samples from DB, while asking for {} keys. You might have asked for duplicate " + "keys, which is not supported.", + sample_fileids.size(), num_keys)); + } + + int64_t current_file_id = sample_fileids.at(0); + uint64_t current_file_start_idx = 0; + std::string current_file_path; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}, current_file_id = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, current_file_id, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); + auto filesystem_wrapper = + get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); + + auto file_wrapper = + get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + + for (uint64_t sample_idx = 0; sample_idx < num_keys; ++sample_idx) { + const int64_t& sample_fileid = sample_fileids.at(sample_idx); + + if (sample_fileid != current_file_id) { + // 1. Prepare response without labels + const std::vector file_indexes( + sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.begin() + static_cast(sample_idx)); + std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + data.clear(); + data.shrink_to_fit(); + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.begin() + static_cast(sample_idx)); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.begin() + static_cast(sample_idx)); + + // 2. 
Send response + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + + // 3. Update state + current_file_id = sample_fileid; + current_file_path = ""; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + const int64_t& previous_fid = sample_fileids.at(sample_idx - 1); + SPDLOG_ERROR( + fmt::format("num_keys = {}, sample_idx = {}, previous_fid = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, sample_idx, previous_fid, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + file_wrapper->set_file_path(current_file_path); + current_file_start_idx = sample_idx; + } + } + + // Send leftovers without labels + const std::vector file_indexes(sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.end()); + const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.end()); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.end()); + + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file_NL: {}", e.what()); + SPDLOG_ERROR("Propagating error up the call chain to handle gRPC calls."); + throw; + } +} diff --git a/modyn/storage/include/internal/grpc/storage_service_impl.hpp b/modyn/storage/include/internal/grpc/storage_service_impl.hpp index 875bf2e53..0fc8189ee 100644 --- a/modyn/storage/include/internal/grpc/storage_service_impl.hpp +++ b/modyn/storage/include/internal/grpc/storage_service_impl.hpp @@ -74,6 +74,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { Status Get(ServerContext* context, const modyn::storage::GetRequest* request, ServerWriter* writer) override; + Status GetNL(ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) override; Status GetNewDataSince(ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, ServerWriter* writer) override; Status GetDataInInterval(ServerContext* context, const modyn::storage::GetDataInIntervalRequest* request, @@ -134,6 +136,55 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { return {StatusCode::INTERNAL, fmt::format("Error in Get: {}", e.what())}; } } + //IMPORTANT SECOND FUNCTION STARTS HERE + template + Status Get_Impl_NL( // 
NOLINT (readability-identifier-naming) + ServerContext* /*context*/, const modyn::storage::GetRequest* request, WriterT* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + std::string dataset_name = request->dataset_id(); + const DatasetData dataset_data = get_dataset_data(session, dataset_name); + + SPDLOG_INFO(fmt::format("Received GetRequest for dataset {} (id = {}) with {} keys.", dataset_name, + dataset_data.dataset_id, request->keys_size())); + + if (dataset_data.dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + session.close(); + return {StatusCode::OK, "Dataset does not exist."}; + } + + const auto keys_size = static_cast(request->keys_size()); + if (keys_size == 0) { + return {StatusCode::OK, "No keys provided."}; + } + + std::vector request_keys; + request_keys.reserve(keys_size); + std::copy(request->keys().begin(), request->keys().end(), std::back_inserter(request_keys)); + + send_sample_data_from_keys_NL(writer, request_keys, dataset_data); //Llegamos hasta aqui + + // sqlite causes memory leaks otherwise + if (session.get_backend_name() != "sqlite3" && session.is_connected()) { + session.close(); + } + + return {StatusCode::OK, "Data retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in Get: {}", e.what()); + return {StatusCode::INTERNAL, fmt::format("Error in Get: {}", e.what())}; + } + } + //ENDS HERE + + + + + + template Status GetNewDataSince_Impl( // NOLINT (readability-identifier-naming) @@ -365,6 +416,66 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { } } } + //EMPIEZA + template > + void send_sample_data_from_keys_NL(WriterT* writer, const std::vector& request_keys, + const DatasetData& dataset_data) { + // Create mutex to protect the writer from concurrent writes as this is not supported by gRPC + std::mutex writer_mutex; + + if (disable_multithreading_) { + const std::vector::const_iterator begin = request_keys.begin(); // NOLINT (modernize-use-auto) + const std::vector::const_iterator end = request_keys.end(); // NOLINT (modernize-use-auto) + + get_samples_and_send_NL(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_);//llegamos aqui + + } else { + std::vector thread_exceptions(retrieval_threads_); + std::mutex exception_mutex; + std::vector::const_iterator, std::vector::const_iterator>> + its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_); + std::vector retrieval_threads_vector(retrieval_threads_); + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + const std::vector::const_iterator begin = its_per_thread[thread_id].first; + const std::vector::const_iterator end = its_per_thread[thread_id].second; + + retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data, + &thread_exceptions, &exception_mutex, this]() { + try { + get_samples_and_send_NL(begin, end, writer, &writer_mutex, &dataset_data, &config_, + sample_batch_size_); + } catch (const std::exception& e) { + const std::lock_guard lock(exception_mutex); + spdlog::error( + fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what())); + thread_exceptions[thread_id] = std::current_exception(); + } + }); + } + + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + if (retrieval_threads_vector[thread_id].joinable()) { + retrieval_threads_vector[thread_id].join(); 
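+          // Joining here blocks until this worker has finished streaming its
+          // GetResponseNoLabels messages through the shared, mutex-guarded writer;
+          // any stored exceptions are rethrown after all threads have joined.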
+ } + } + retrieval_threads_vector.clear(); + // In order for the gRPC call to return an error, we need to rethrow the threaded exceptions. + for (auto& e_ptr : thread_exceptions) { + if (e_ptr) { + try { + std::rethrow_exception(e_ptr); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error while unwinding thread: {}\nPropagating it up the call chain.", e.what()); + throw; + } + } + } + } + } +//ACABA AQUI + + + template > void send_file_ids_and_labels(WriterT* writer, const int64_t dataset_id, const int64_t start_timestamp = -1, @@ -690,6 +801,183 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { } } + + + + + + //"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +//"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +// Adapted function to send sample data without labels +template > +static void send_sample_data_for_keys_and_file_NL( // NOLINT(readability-function-cognitive-complexity) //TODO Adaptar esto + WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + + // Note that we currently ignore the sample batch size here, under the assumption that users do not request more + // keys than this + try { + const uint64_t num_keys = sample_keys.size(); + + if (num_keys == 0) { + SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file_NL"); + return; + } + + // Removed labels-related vectors + // std::vector sample_labels(num_keys); + std::vector sample_indices(num_keys); + std::vector sample_fileids(num_keys); + + const std::string sample_query = fmt::format( + "SELECT sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER BY file_id", + fmt::join(sample_keys, ",")); + session << sample_query, + soci::into(sample_indices), + soci::into(sample_fileids), + soci::use(dataset_data.dataset_id); + + if (sample_fileids.size() != num_keys) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException( + fmt::format("Got back {} samples from DB, while asking for {} keys. 
You might have asked for duplicate " + "keys, which is not supported.", + sample_fileids.size(), num_keys)); + } + + int64_t current_file_id = sample_fileids.at(0); + uint64_t current_file_start_idx = 0; + std::string current_file_path; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + SPDLOG_ERROR( + fmt::format("num_keys = {}, current_file_id = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, current_file_id, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); + auto filesystem_wrapper = + get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); + + auto file_wrapper = + get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + + for (uint64_t sample_idx = 0; sample_idx < num_keys; ++sample_idx) { + const int64_t& sample_fileid = sample_fileids.at(sample_idx); + + if (sample_fileid != current_file_id) { + // 1. Prepare response without labels + const std::vector file_indexes( + sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.begin() + static_cast(sample_idx)); + std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + data.clear(); + data.shrink_to_fit(); + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.begin() + static_cast(sample_idx)); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.begin() + static_cast(sample_idx)); + + // 2. Send response + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + + // 3. 
Update state + current_file_id = sample_fileid; + current_file_path = ""; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + const int64_t& previous_fid = sample_fileids.at(sample_idx - 1); + SPDLOG_ERROR( + fmt::format("num_keys = {}, sample_idx = {}, previous_fid = {}\n sample_indices = [{}]\n sample_fileids = [{}]", + num_keys, sample_idx, previous_fid, + fmt::join(sample_indices, ", "), + fmt::join(sample_fileids, ", "))); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } + file_wrapper->set_file_path(current_file_path); + current_file_start_idx = sample_idx; + } + } + + // Send leftovers without labels + const std::vector file_indexes(sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.end()); + const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + + // Changed GetResponse to GetResponseNoLabels + modyn::storage::GetResponseNoLabels response; // <-- Changed from GetResponse + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.end()); + // Removed labels assignment + // response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + // sample_labels.end()); + + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); // <-- Correct type: GetResponseNoLabels + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file_NL: {}", e.what()); + SPDLOG_ERROR("Propagating error up the call chain to handle gRPC calls."); + throw; + } +} + + //"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + + + + + + + + + + + + + + + template static void get_samples_and_send(const std::vector::const_iterator begin, const std::vector::const_iterator end, WriterT* writer, @@ -705,7 +993,21 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { sample_batch_size); session.close(); } - + template + static void get_samples_and_send_NL(const std::vector::const_iterator begin, + const std::vector::const_iterator end, WriterT* writer, + std::mutex* writer_mutex, const DatasetData* dataset_data, const YAML::Node* config, + int64_t sample_batch_size) { + if (begin >= end) { + return; + } + const StorageDatabaseConnection storage_database_connection(*config); + soci::session session = storage_database_connection.get_session(); + const std::vector sample_keys(begin, end); + send_sample_data_for_keys_and_file_NL(writer, *writer_mutex, sample_keys, *dataset_data, session, + sample_batch_size); + session.close(); + } static std::string get_timestamp_condition(const int64_t start_timestamp = -1, const int64_t end_timestamp = -1) { std::string timestamp_filter; if (start_timestamp >= 0 && end_timestamp == -1) { diff --git 
a/modyn/storage/internal/grpc/generated/storage_pb2.py b/modyn/storage/internal/grpc/generated/storage_pb2.py index 8d59a8b1f..b24fd735d 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: storage.proto -# Protobuf Python Version: 5.26.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" - from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool + from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder @@ -14,55 +15,57 @@ _sym_db = _symbol_database.Default() + from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\rstorage.proto\x12\rmodyn.storage".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\x1c\n\x1aGetCurrentTimestampRequest"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 
\x03(\x03"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe2\x07\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse"\x00\x62\x06proto3' -) + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rstorage.proto\x12\rmodyn.storage\"F\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x16\n\x0einclude_labels\x18\x03 \x01(\x08\"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"4\n\x13GetResponseNoLabels\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"\x1c\n\x1aGetCurrentTimestampRequest\"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03\"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03\"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 
\x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03\"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xae\x08\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse\"\x00\x30\x01\x12J\n\x05GetNL\x12\x19.modyn.storage.GetRequest\x1a\".modyn.storage.GetResponseNoLabels\"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse\"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse\"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse\"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse\"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse\"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse\"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse\"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse\"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "storage_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'storage_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_GETREQUEST"]._serialized_start = 32 - _globals["_GETREQUEST"]._serialized_end = 78 - _globals["_GETRESPONSE"]._serialized_start = 80 - _globals["_GETRESPONSE"]._serialized_end = 140 - _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_start = 142 - _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_end = 170 - _globals["_GETNEWDATASINCEREQUEST"]._serialized_start = 172 - _globals["_GETNEWDATASINCEREQUEST"]._serialized_end = 235 - _globals["_GETNEWDATASINCERESPONSE"]._serialized_start = 237 - _globals["_GETNEWDATASINCERESPONSE"]._serialized_end = 312 - _globals["_GETDATAININTERVALREQUEST"]._serialized_start = 314 - _globals["_GETDATAININTERVALREQUEST"]._serialized_end = 408 - _globals["_GETDATAININTERVALRESPONSE"]._serialized_start = 410 - _globals["_GETDATAININTERVALRESPONSE"]._serialized_end = 487 - _globals["_GETDATAPERWORKERREQUEST"]._serialized_start = 490 - _globals["_GETDATAPERWORKERREQUEST"]._serialized_end = 673 - _globals["_GETDATAPERWORKERRESPONSE"]._serialized_start = 675 - _globals["_GETDATAPERWORKERRESPONSE"]._serialized_end = 715 - _globals["_GETDATASETSIZEREQUEST"]._serialized_start = 718 - _globals["_GETDATASETSIZEREQUEST"]._serialized_end = 857 - _globals["_GETDATASETSIZERESPONSE"]._serialized_start = 859 - 
_globals["_GETDATASETSIZERESPONSE"]._serialized_end = 918 - _globals["_DATASETAVAILABLEREQUEST"]._serialized_start = 920 - _globals["_DATASETAVAILABLEREQUEST"]._serialized_end = 965 - _globals["_DATASETAVAILABLERESPONSE"]._serialized_start = 967 - _globals["_DATASETAVAILABLERESPONSE"]._serialized_end = 1012 - _globals["_REGISTERNEWDATASETREQUEST"]._serialized_start = 1015 - _globals["_REGISTERNEWDATASETREQUEST"]._serialized_end = 1270 - _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_start = 1272 - _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_end = 1317 - _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_start = 1319 - _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_end = 1367 - _globals["_DELETEDATASETRESPONSE"]._serialized_start = 1369 - _globals["_DELETEDATASETRESPONSE"]._serialized_end = 1409 - _globals["_DELETEDATAREQUEST"]._serialized_start = 1411 - _globals["_DELETEDATAREQUEST"]._serialized_end = 1464 - _globals["_DELETEDATARESPONSE"]._serialized_start = 1466 - _globals["_DELETEDATARESPONSE"]._serialized_end = 1503 - _globals["_STORAGE"]._serialized_start = 1506 - _globals["_STORAGE"]._serialized_end = 2500 + DESCRIPTOR._loaded_options = None + _globals['_GETREQUEST']._serialized_start=32 + _globals['_GETREQUEST']._serialized_end=102 + _globals['_GETRESPONSE']._serialized_start=104 + _globals['_GETRESPONSE']._serialized_end=164 + _globals['_GETRESPONSENOLABELS']._serialized_start=166 + _globals['_GETRESPONSENOLABELS']._serialized_end=218 + _globals['_GETCURRENTTIMESTAMPREQUEST']._serialized_start=220 + _globals['_GETCURRENTTIMESTAMPREQUEST']._serialized_end=248 + _globals['_GETNEWDATASINCEREQUEST']._serialized_start=250 + _globals['_GETNEWDATASINCEREQUEST']._serialized_end=313 + _globals['_GETNEWDATASINCERESPONSE']._serialized_start=315 + _globals['_GETNEWDATASINCERESPONSE']._serialized_end=390 + _globals['_GETDATAININTERVALREQUEST']._serialized_start=392 + _globals['_GETDATAININTERVALREQUEST']._serialized_end=486 + _globals['_GETDATAININTERVALRESPONSE']._serialized_start=488 + _globals['_GETDATAININTERVALRESPONSE']._serialized_end=565 + _globals['_GETDATAPERWORKERREQUEST']._serialized_start=568 + _globals['_GETDATAPERWORKERREQUEST']._serialized_end=751 + _globals['_GETDATAPERWORKERRESPONSE']._serialized_start=753 + _globals['_GETDATAPERWORKERRESPONSE']._serialized_end=793 + _globals['_GETDATASETSIZEREQUEST']._serialized_start=796 + _globals['_GETDATASETSIZEREQUEST']._serialized_end=935 + _globals['_GETDATASETSIZERESPONSE']._serialized_start=937 + _globals['_GETDATASETSIZERESPONSE']._serialized_end=996 + _globals['_DATASETAVAILABLEREQUEST']._serialized_start=998 + _globals['_DATASETAVAILABLEREQUEST']._serialized_end=1043 + _globals['_DATASETAVAILABLERESPONSE']._serialized_start=1045 + _globals['_DATASETAVAILABLERESPONSE']._serialized_end=1090 + _globals['_REGISTERNEWDATASETREQUEST']._serialized_start=1093 + _globals['_REGISTERNEWDATASETREQUEST']._serialized_end=1348 + _globals['_REGISTERNEWDATASETRESPONSE']._serialized_start=1350 + _globals['_REGISTERNEWDATASETRESPONSE']._serialized_end=1395 + _globals['_GETCURRENTTIMESTAMPRESPONSE']._serialized_start=1397 + _globals['_GETCURRENTTIMESTAMPRESPONSE']._serialized_end=1445 + _globals['_DELETEDATASETRESPONSE']._serialized_start=1447 + _globals['_DELETEDATASETRESPONSE']._serialized_end=1487 + _globals['_DELETEDATAREQUEST']._serialized_start=1489 + _globals['_DELETEDATAREQUEST']._serialized_end=1542 + _globals['_DELETEDATARESPONSE']._serialized_start=1544 + _globals['_DELETEDATARESPONSE']._serialized_end=1581 + 
_globals['_STORAGE']._serialized_start=1584 + _globals['_STORAGE']._serialized_end=2654 # @@protoc_insertion_point(module_scope) diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.pyi b/modyn/storage/internal/grpc/generated/storage_pb2.pyi index cf286ada1..a6c340801 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.pyi +++ b/modyn/storage/internal/grpc/generated/storage_pb2.pyi @@ -18,7 +18,10 @@ class GetRequest(google.protobuf.message.Message): DATASET_ID_FIELD_NUMBER: builtins.int KEYS_FIELD_NUMBER: builtins.int + INCLUDE_LABELS_FIELD_NUMBER: builtins.int dataset_id: builtins.str + include_labels: builtins.bool + """Added this line""" @property def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__( @@ -26,8 +29,9 @@ class GetRequest(google.protobuf.message.Message): *, dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., + include_labels: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "include_labels", b"include_labels", "keys", b"keys"]) -> None: ... global___GetRequest = GetRequest @@ -51,12 +55,30 @@ class GetResponse(google.protobuf.message.Message): keys: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"]) -> None: ... global___GetResponse = GetResponse +@typing.final +class GetResponseNoLabels(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "samples", b"samples"]) -> None: ... + +global___GetResponseNoLabels = GetResponseNoLabels + @typing.final class GetCurrentTimestampRequest(google.protobuf.message.Message): """https://github.com/grpc/grpc/issues/15937""" @@ -83,9 +105,7 @@ class GetNewDataSinceRequest(google.protobuf.message.Message): dataset_id: builtins.str = ..., timestamp: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"]) -> None: ... global___GetNewDataSinceRequest = GetNewDataSinceRequest @@ -109,9 +129,7 @@ class GetNewDataSinceResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] - ) -> None: ... 
+ def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... global___GetNewDataSinceResponse = GetNewDataSinceResponse @@ -132,12 +150,7 @@ class GetDataInIntervalRequest(google.protobuf.message.Message): start_timestamp: builtins.int = ..., end_timestamp: builtins.int = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp" - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... global___GetDataInIntervalRequest = GetDataInIntervalRequest @@ -161,9 +174,7 @@ class GetDataInIntervalResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... global___GetDataInIntervalResponse = GetDataInIntervalResponse @@ -193,46 +204,12 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): start_timestamp: builtins.int | None = ..., end_timestamp: builtins.int | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "dataset_id", - b"dataset_id", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - "total_workers", - b"total_workers", - "worker_id", - b"worker_id", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp", "total_workers", b"total_workers", "worker_id", b"worker_id"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... global___GetDataPerWorkerRequest = GetDataPerWorkerRequest @@ -272,42 +249,12 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): start_timestamp: builtins.int | None = ..., end_timestamp: builtins.int | None = ..., ) -> None: ... 
- def HasField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_end_timestamp", - b"_end_timestamp", - "_start_timestamp", - b"_start_timestamp", - "dataset_id", - b"dataset_id", - "end_timestamp", - b"end_timestamp", - "start_timestamp", - b"start_timestamp", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... global___GetDatasetSizeRequest = GetDatasetSizeRequest @@ -394,29 +341,7 @@ class RegisterNewDatasetRequest(google.protobuf.message.Message): ignore_last_timestamp: builtins.bool = ..., file_watcher_interval: builtins.int = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "base_path", - b"base_path", - "dataset_id", - b"dataset_id", - "description", - b"description", - "file_watcher_interval", - b"file_watcher_interval", - "file_wrapper_config", - b"file_wrapper_config", - "file_wrapper_type", - b"file_wrapper_type", - "filesystem_wrapper_type", - b"filesystem_wrapper_type", - "ignore_last_timestamp", - b"ignore_last_timestamp", - "version", - b"version", - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["base_path", b"base_path", "dataset_id", b"dataset_id", "description", b"description", "file_watcher_interval", b"file_watcher_interval", "file_wrapper_config", b"file_wrapper_config", "file_wrapper_type", b"file_wrapper_type", "filesystem_wrapper_type", b"filesystem_wrapper_type", "ignore_last_timestamp", b"ignore_last_timestamp", "version", b"version"]) -> None: ... global___RegisterNewDatasetRequest = RegisterNewDatasetRequest diff --git a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py index cae27014a..7cda0b538 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py @@ -1,34 +1,27 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" - +import grpc import warnings -import grpc import modyn.storage.internal.grpc.generated.storage_pb2 as storage__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = '1.67.1' GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) except ImportError: _version_not_supported = True if _version_not_supported: - warnings.warn( - f"The grpc package installed is at version {GRPC_VERSION}," - + f" but the generated code in storage_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." - + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." - + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in storage_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' ) @@ -42,65 +35,60 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Get = channel.unary_stream( - "/modyn.storage.Storage/Get", - request_serializer=storage__pb2.GetRequest.SerializeToString, - response_deserializer=storage__pb2.GetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/Get', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + _registered_method=True) + self.GetNL = channel.unary_stream( + '/modyn.storage.Storage/GetNL', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponseNoLabels.FromString, + _registered_method=True) self.GetNewDataSince = channel.unary_stream( - "/modyn.storage.Storage/GetNewDataSince", - request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, - response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetNewDataSince', + request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, + response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, + _registered_method=True) self.GetDataInInterval = channel.unary_stream( - "/modyn.storage.Storage/GetDataInInterval", - request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDataInInterval', + request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, + _registered_method=True) self.GetDataPerWorker = channel.unary_stream( - "/modyn.storage.Storage/GetDataPerWorker", - request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDataPerWorker', + 
request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, + _registered_method=True) self.GetDatasetSize = channel.unary_unary( - "/modyn.storage.Storage/GetDatasetSize", - request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, - response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetDatasetSize', + request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, + response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, + _registered_method=True) self.CheckAvailability = channel.unary_unary( - "/modyn.storage.Storage/CheckAvailability", - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/CheckAvailability', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, + _registered_method=True) self.RegisterNewDataset = channel.unary_unary( - "/modyn.storage.Storage/RegisterNewDataset", - request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, - response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/RegisterNewDataset', + request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, + response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, + _registered_method=True) self.GetCurrentTimestamp = channel.unary_unary( - "/modyn.storage.Storage/GetCurrentTimestamp", - request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, - response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/GetCurrentTimestamp', + request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, + response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, + _registered_method=True) self.DeleteDataset = channel.unary_unary( - "/modyn.storage.Storage/DeleteDataset", - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/DeleteDataset', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, + _registered_method=True) self.DeleteData = channel.unary_unary( - "/modyn.storage.Storage/DeleteData", - request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDataResponse.FromString, - _registered_method=True, - ) + '/modyn.storage.Storage/DeleteData', + request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDataResponse.FromString, + _registered_method=True) class StorageServicer(object): @@ -109,142 +97,153 @@ class StorageServicer(object): def Get(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + 
def GetNL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetNewDataSince(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDataInInterval(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDataPerWorker(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetDatasetSize(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def CheckAvailability(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def RegisterNewDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GetCurrentTimestamp(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def DeleteDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def DeleteData(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def 
add_StorageServicer_to_server(servicer, server): rpc_method_handlers = { - "Get": grpc.unary_stream_rpc_method_handler( - servicer.Get, - request_deserializer=storage__pb2.GetRequest.FromString, - response_serializer=storage__pb2.GetResponse.SerializeToString, - ), - "GetNewDataSince": grpc.unary_stream_rpc_method_handler( - servicer.GetNewDataSince, - request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, - response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, - ), - "GetDataInInterval": grpc.unary_stream_rpc_method_handler( - servicer.GetDataInInterval, - request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, - response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, - ), - "GetDataPerWorker": grpc.unary_stream_rpc_method_handler( - servicer.GetDataPerWorker, - request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, - response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, - ), - "GetDatasetSize": grpc.unary_unary_rpc_method_handler( - servicer.GetDatasetSize, - request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, - response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, - ), - "CheckAvailability": grpc.unary_unary_rpc_method_handler( - servicer.CheckAvailability, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, - ), - "RegisterNewDataset": grpc.unary_unary_rpc_method_handler( - servicer.RegisterNewDataset, - request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, - response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, - ), - "GetCurrentTimestamp": grpc.unary_unary_rpc_method_handler( - servicer.GetCurrentTimestamp, - request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, - response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, - ), - "DeleteDataset": grpc.unary_unary_rpc_method_handler( - servicer.DeleteDataset, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, - ), - "DeleteData": grpc.unary_unary_rpc_method_handler( - servicer.DeleteData, - request_deserializer=storage__pb2.DeleteDataRequest.FromString, - response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, - ), + 'Get': grpc.unary_stream_rpc_method_handler( + servicer.Get, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponse.SerializeToString, + ), + 'GetNL': grpc.unary_stream_rpc_method_handler( + servicer.GetNL, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponseNoLabels.SerializeToString, + ), + 'GetNewDataSince': grpc.unary_stream_rpc_method_handler( + servicer.GetNewDataSince, + request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, + response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, + ), + 'GetDataInInterval': grpc.unary_stream_rpc_method_handler( + servicer.GetDataInInterval, + request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, + response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, + ), + 'GetDataPerWorker': grpc.unary_stream_rpc_method_handler( + servicer.GetDataPerWorker, + request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, + 
response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, + ), + 'GetDatasetSize': grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetSize, + request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, + response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, + ), + 'CheckAvailability': grpc.unary_unary_rpc_method_handler( + servicer.CheckAvailability, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, + ), + 'RegisterNewDataset': grpc.unary_unary_rpc_method_handler( + servicer.RegisterNewDataset, + request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, + response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, + ), + 'GetCurrentTimestamp': grpc.unary_unary_rpc_method_handler( + servicer.GetCurrentTimestamp, + request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, + response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, + ), + 'DeleteDataset': grpc.unary_unary_rpc_method_handler( + servicer.DeleteDataset, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, + ), + 'DeleteData': grpc.unary_unary_rpc_method_handler( + servicer.DeleteData, + request_deserializer=storage__pb2.DeleteDataRequest.FromString, + response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler("modyn.storage.Storage", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + 'modyn.storage.Storage', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('modyn.storage.Storage', rpc_method_handlers) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. 
class Storage(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Get( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def Get(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/Get", + '/modyn.storage.Storage/Get', storage__pb2.GetRequest.SerializeToString, storage__pb2.GetResponse.FromString, options, @@ -255,26 +254,50 @@ def Get( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) + + @staticmethod + def GetNL(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/modyn.storage.Storage/GetNL', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponseNoLabels.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod - def GetNewDataSince( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetNewDataSince(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetNewDataSince", + '/modyn.storage.Storage/GetNewDataSince', storage__pb2.GetNewDataSinceRequest.SerializeToString, storage__pb2.GetNewDataSinceResponse.FromString, options, @@ -285,26 +308,23 @@ def GetNewDataSince( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDataInInterval( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDataInInterval(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetDataInInterval", + '/modyn.storage.Storage/GetDataInInterval', storage__pb2.GetDataInIntervalRequest.SerializeToString, storage__pb2.GetDataInIntervalResponse.FromString, options, @@ -315,26 +335,23 @@ def GetDataInInterval( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDataPerWorker( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDataPerWorker(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream( request, target, - "/modyn.storage.Storage/GetDataPerWorker", 
+ '/modyn.storage.Storage/GetDataPerWorker', storage__pb2.GetDataPerWorkerRequest.SerializeToString, storage__pb2.GetDataPerWorkerResponse.FromString, options, @@ -345,26 +362,23 @@ def GetDataPerWorker( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetDatasetSize( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetDatasetSize(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/GetDatasetSize", + '/modyn.storage.Storage/GetDatasetSize', storage__pb2.GetDatasetSizeRequest.SerializeToString, storage__pb2.GetDatasetSizeResponse.FromString, options, @@ -375,26 +389,23 @@ def GetDatasetSize( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def CheckAvailability( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def CheckAvailability(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/CheckAvailability", + '/modyn.storage.Storage/CheckAvailability', storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DatasetAvailableResponse.FromString, options, @@ -405,26 +416,23 @@ def CheckAvailability( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def RegisterNewDataset( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def RegisterNewDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/RegisterNewDataset", + '/modyn.storage.Storage/RegisterNewDataset', storage__pb2.RegisterNewDatasetRequest.SerializeToString, storage__pb2.RegisterNewDatasetResponse.FromString, options, @@ -435,26 +443,23 @@ def RegisterNewDataset( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def GetCurrentTimestamp( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GetCurrentTimestamp(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/GetCurrentTimestamp", + '/modyn.storage.Storage/GetCurrentTimestamp', storage__pb2.GetCurrentTimestampRequest.SerializeToString, storage__pb2.GetCurrentTimestampResponse.FromString, options, @@ -465,26 +470,23 @@ def GetCurrentTimestamp( wait_for_ready, 
timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def DeleteDataset( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def DeleteDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/DeleteDataset", + '/modyn.storage.Storage/DeleteDataset', storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DeleteDatasetResponse.FromString, options, @@ -495,26 +497,23 @@ def DeleteDataset( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) @staticmethod - def DeleteData( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def DeleteData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modyn.storage.Storage/DeleteData", + '/modyn.storage.Storage/DeleteData', storage__pb2.DeleteDataRequest.SerializeToString, storage__pb2.DeleteDataResponse.FromString, options, @@ -525,5 +524,4 @@ def DeleteData( wait_for_ready, timeout, metadata, - _registered_method=True, - ) + _registered_method=True) diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2.py b/modyn/storage/src/internal/database/grpc2/storage_pb2.py new file mode 100644 index 000000000..18ff61e95 --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: storage.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rstorage.proto\x12\rmodyn.storage\"F\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x16\n\x0einclude_labels\x18\x03 \x01(\x08\"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"4\n\x13GetResponseNoLabels\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"\x1c\n\x1aGetCurrentTimestampRequest\"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03\"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp\";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03\"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03\"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x32\xae\x08\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse\"\x00\x30\x01\x12J\n\x05GetNL\x12\x19.modyn.storage.GetRequest\x1a\".modyn.storage.GetResponseNoLabels\"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse\"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse\"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse\"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse\"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse\"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse\"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse\"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse\"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'storage_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _GETREQUEST._serialized_start=32 + _GETREQUEST._serialized_end=102 + _GETRESPONSE._serialized_start=104 + _GETRESPONSE._serialized_end=164 + _GETRESPONSENOLABELS._serialized_start=166 + _GETRESPONSENOLABELS._serialized_end=218 + _GETCURRENTTIMESTAMPREQUEST._serialized_start=220 + _GETCURRENTTIMESTAMPREQUEST._serialized_end=248 + _GETNEWDATASINCEREQUEST._serialized_start=250 + _GETNEWDATASINCEREQUEST._serialized_end=313 + _GETNEWDATASINCERESPONSE._serialized_start=315 + _GETNEWDATASINCERESPONSE._serialized_end=390 + _GETDATAININTERVALREQUEST._serialized_start=392 + _GETDATAININTERVALREQUEST._serialized_end=486 + _GETDATAININTERVALRESPONSE._serialized_start=488 + _GETDATAININTERVALRESPONSE._serialized_end=565 + _GETDATAPERWORKERREQUEST._serialized_start=568 + _GETDATAPERWORKERREQUEST._serialized_end=751 + _GETDATAPERWORKERRESPONSE._serialized_start=753 + _GETDATAPERWORKERRESPONSE._serialized_end=793 + _GETDATASETSIZEREQUEST._serialized_start=796 + _GETDATASETSIZEREQUEST._serialized_end=935 + _GETDATASETSIZERESPONSE._serialized_start=937 + _GETDATASETSIZERESPONSE._serialized_end=996 + _DATASETAVAILABLEREQUEST._serialized_start=998 + _DATASETAVAILABLEREQUEST._serialized_end=1043 + _DATASETAVAILABLERESPONSE._serialized_start=1045 + _DATASETAVAILABLERESPONSE._serialized_end=1090 + _REGISTERNEWDATASETREQUEST._serialized_start=1093 + _REGISTERNEWDATASETREQUEST._serialized_end=1348 + _REGISTERNEWDATASETRESPONSE._serialized_start=1350 + _REGISTERNEWDATASETRESPONSE._serialized_end=1395 + _GETCURRENTTIMESTAMPRESPONSE._serialized_start=1397 + _GETCURRENTTIMESTAMPRESPONSE._serialized_end=1445 + _DELETEDATASETRESPONSE._serialized_start=1447 + _DELETEDATASETRESPONSE._serialized_end=1487 + _DELETEDATAREQUEST._serialized_start=1489 + _DELETEDATAREQUEST._serialized_end=1542 + _DELETEDATARESPONSE._serialized_start=1544 + _DELETEDATARESPONSE._serialized_end=1581 + _STORAGE._serialized_start=1584 + _STORAGE._serialized_end=2654 +# @@protoc_insertion_point(module_scope) 
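
For orientation, the sketch below (not part of this patch) shows how a client could exercise the new GetNL RPC and the GetResponseNoLabels message that the generated code above introduces, i.e. streaming samples without labels, as a generative pipeline would. The channel address, dataset id, and keys are placeholder assumptions; only the stub, request, and response names come from this diff.

# Minimal client sketch, assuming a reachable storage server; not part of the patch.
import grpc

from modyn.storage.internal.grpc.generated.storage_pb2 import GetRequest
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub


def fetch_samples_without_labels(address: str, dataset_id: str, keys: list[int]) -> list[bytes]:
    """Stream raw samples (no labels) for the given keys via the new GetNL RPC."""
    samples: list[bytes] = []
    with grpc.insecure_channel(address) as channel:
        stub = StorageStub(channel)
        # GetNL reuses GetRequest but streams GetResponseNoLabels messages,
        # which carry only `samples` and `keys` (no `labels` field).
        request = GetRequest(dataset_id=dataset_id, keys=keys, include_labels=False)
        for response in stub.GetNL(request):
            samples.extend(response.samples)
    return samples


# Hypothetical usage:
# samples = fetch_samples_without_labels("storage:50051", "my_dataset", keys=[1, 2, 3])

The design point illustrated here is that label-free retrieval gets its own response type and RPC rather than overloading GetResponse, so existing Get callers are unaffected while generative training can skip label transfer entirely.
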
diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi b/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi new file mode 100644 index 000000000..a6c340801 --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2.pyi @@ -0,0 +1,425 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class GetRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + INCLUDE_LABELS_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + include_labels: builtins.bool + """Added this line""" + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + dataset_id: builtins.str = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + include_labels: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "include_labels", b"include_labels", "keys", b"keys"]) -> None: ... + +global___GetRequest = GetRequest + +@typing.final +class GetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"]) -> None: ... + +global___GetResponse = GetResponse + +@typing.final +class GetResponseNoLabels(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SAMPLES_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + @property + def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + samples: collections.abc.Iterable[builtins.bytes] | None = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "samples", b"samples"]) -> None: ... + +global___GetResponseNoLabels = GetResponseNoLabels + +@typing.final +class GetCurrentTimestampRequest(google.protobuf.message.Message): + """https://github.com/grpc/grpc/issues/15937""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... 
+ +global___GetCurrentTimestampRequest = GetCurrentTimestampRequest + +@typing.final +class GetNewDataSinceRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"]) -> None: ... + +global___GetNewDataSinceRequest = GetNewDataSinceRequest + +@typing.final +class GetNewDataSinceResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + TIMESTAMPS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + timestamps: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... + +global___GetNewDataSinceResponse = GetNewDataSinceResponse + +@typing.final +class GetDataInIntervalRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + start_timestamp: builtins.int + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + start_timestamp: builtins.int = ..., + end_timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... + +global___GetDataInIntervalRequest = GetDataInIntervalRequest + +@typing.final +class GetDataInIntervalResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + TIMESTAMPS_FIELD_NUMBER: builtins.int + LABELS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + @property + def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + timestamps: collections.abc.Iterable[builtins.int] | None = ..., + labels: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... 
+ +global___GetDataInIntervalResponse = GetDataInIntervalResponse + +@typing.final +class GetDataPerWorkerRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + WORKER_ID_FIELD_NUMBER: builtins.int + TOTAL_WORKERS_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + worker_id: builtins.int + total_workers: builtins.int + start_timestamp: builtins.int + """value unset means no limit + start_timestamp is inclusive, end_timestamp is exclusive + """ + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + worker_id: builtins.int = ..., + total_workers: builtins.int = ..., + start_timestamp: builtins.int | None = ..., + end_timestamp: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp", "total_workers", b"total_workers", "worker_id", b"worker_id"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... + +global___GetDataPerWorkerRequest = GetDataPerWorkerRequest + +@typing.final +class GetDataPerWorkerResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["keys", b"keys"]) -> None: ... + +global___GetDataPerWorkerResponse = GetDataPerWorkerResponse + +@typing.final +class GetDatasetSizeRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + START_TIMESTAMP_FIELD_NUMBER: builtins.int + END_TIMESTAMP_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + start_timestamp: builtins.int + """value unset means no limit + start_timestamp is inclusive, end_timestamp is exclusive + """ + end_timestamp: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + start_timestamp: builtins.int | None = ..., + end_timestamp: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_end_timestamp", b"_end_timestamp", "_start_timestamp", b"_start_timestamp", "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... 
+ @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"]) -> typing.Literal["end_timestamp"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"]) -> typing.Literal["start_timestamp"] | None: ... + +global___GetDatasetSizeRequest = GetDatasetSizeRequest + +@typing.final +class GetDatasetSizeResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + NUM_KEYS_FIELD_NUMBER: builtins.int + success: builtins.bool + num_keys: builtins.int + def __init__( + self, + *, + success: builtins.bool = ..., + num_keys: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["num_keys", b"num_keys", "success", b"success"]) -> None: ... + +global___GetDatasetSizeResponse = GetDatasetSizeResponse + +@typing.final +class DatasetAvailableRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + def __init__( + self, + *, + dataset_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id"]) -> None: ... + +global___DatasetAvailableRequest = DatasetAvailableRequest + +@typing.final +class DatasetAvailableResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + AVAILABLE_FIELD_NUMBER: builtins.int + available: builtins.bool + def __init__( + self, + *, + available: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["available", b"available"]) -> None: ... + +global___DatasetAvailableResponse = DatasetAvailableResponse + +@typing.final +class RegisterNewDatasetRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + FILESYSTEM_WRAPPER_TYPE_FIELD_NUMBER: builtins.int + FILE_WRAPPER_TYPE_FIELD_NUMBER: builtins.int + DESCRIPTION_FIELD_NUMBER: builtins.int + BASE_PATH_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + FILE_WRAPPER_CONFIG_FIELD_NUMBER: builtins.int + IGNORE_LAST_TIMESTAMP_FIELD_NUMBER: builtins.int + FILE_WATCHER_INTERVAL_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + filesystem_wrapper_type: builtins.str + file_wrapper_type: builtins.str + description: builtins.str + base_path: builtins.str + version: builtins.str + file_wrapper_config: builtins.str + ignore_last_timestamp: builtins.bool + file_watcher_interval: builtins.int + def __init__( + self, + *, + dataset_id: builtins.str = ..., + filesystem_wrapper_type: builtins.str = ..., + file_wrapper_type: builtins.str = ..., + description: builtins.str = ..., + base_path: builtins.str = ..., + version: builtins.str = ..., + file_wrapper_config: builtins.str = ..., + ignore_last_timestamp: builtins.bool = ..., + file_watcher_interval: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["base_path", b"base_path", "dataset_id", b"dataset_id", "description", b"description", "file_watcher_interval", b"file_watcher_interval", "file_wrapper_config", b"file_wrapper_config", "file_wrapper_type", b"file_wrapper_type", "filesystem_wrapper_type", b"filesystem_wrapper_type", "ignore_last_timestamp", b"ignore_last_timestamp", "version", b"version"]) -> None: ... 
+ +global___RegisterNewDatasetRequest = RegisterNewDatasetRequest + +@typing.final +class RegisterNewDatasetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___RegisterNewDatasetResponse = RegisterNewDatasetResponse + +@typing.final +class GetCurrentTimestampResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TIMESTAMP_FIELD_NUMBER: builtins.int + timestamp: builtins.int + def __init__( + self, + *, + timestamp: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["timestamp", b"timestamp"]) -> None: ... + +global___GetCurrentTimestampResponse = GetCurrentTimestampResponse + +@typing.final +class DeleteDatasetResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___DeleteDatasetResponse = DeleteDatasetResponse + +@typing.final +class DeleteDataRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATASET_ID_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + dataset_id: builtins.str + @property + def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__( + self, + *, + dataset_id: builtins.str = ..., + keys: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + +global___DeleteDataRequest = DeleteDataRequest + +@typing.final +class DeleteDataResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__( + self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + +global___DeleteDataResponse = DeleteDataResponse diff --git a/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py b/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py new file mode 100644 index 000000000..60712130f --- /dev/null +++ b/modyn/storage/src/internal/database/grpc2/storage_pb2_grpc.py @@ -0,0 +1,396 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import storage_pb2 as storage__pb2 + + +class StorageStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Get = channel.unary_stream( + '/modyn.storage.Storage/Get', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + ) + self.GetNL = channel.unary_stream( + '/modyn.storage.Storage/GetNL', + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponseNoLabels.FromString, + ) + self.GetNewDataSince = channel.unary_stream( + '/modyn.storage.Storage/GetNewDataSince', + request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, + response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, + ) + self.GetDataInInterval = channel.unary_stream( + '/modyn.storage.Storage/GetDataInInterval', + request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, + ) + self.GetDataPerWorker = channel.unary_stream( + '/modyn.storage.Storage/GetDataPerWorker', + request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, + ) + self.GetDatasetSize = channel.unary_unary( + '/modyn.storage.Storage/GetDatasetSize', + request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, + response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, + ) + self.CheckAvailability = channel.unary_unary( + '/modyn.storage.Storage/CheckAvailability', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, + ) + self.RegisterNewDataset = channel.unary_unary( + '/modyn.storage.Storage/RegisterNewDataset', + request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, + response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, + ) + self.GetCurrentTimestamp = channel.unary_unary( + '/modyn.storage.Storage/GetCurrentTimestamp', + request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, + response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, + ) + self.DeleteDataset = channel.unary_unary( + '/modyn.storage.Storage/DeleteDataset', + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, + ) + self.DeleteData = channel.unary_unary( + '/modyn.storage.Storage/DeleteData', + request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDataResponse.FromString, + ) + + +class StorageServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Get(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetNewDataSince(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDataInInterval(self, request, context): + 
"""Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDataPerWorker(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDatasetSize(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CheckAvailability(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RegisterNewDataset(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetCurrentTimestamp(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteDataset(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteData(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_StorageServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Get': grpc.unary_stream_rpc_method_handler( + servicer.Get, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponse.SerializeToString, + ), + 'GetNL': grpc.unary_stream_rpc_method_handler( + servicer.GetNL, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponseNoLabels.SerializeToString, + ), + 'GetNewDataSince': grpc.unary_stream_rpc_method_handler( + servicer.GetNewDataSince, + request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, + response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, + ), + 'GetDataInInterval': grpc.unary_stream_rpc_method_handler( + servicer.GetDataInInterval, + request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, + response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, + ), + 'GetDataPerWorker': grpc.unary_stream_rpc_method_handler( + servicer.GetDataPerWorker, + request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, + response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, + ), + 'GetDatasetSize': grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetSize, + request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, + response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, + ), + 
'CheckAvailability': grpc.unary_unary_rpc_method_handler( + servicer.CheckAvailability, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, + ), + 'RegisterNewDataset': grpc.unary_unary_rpc_method_handler( + servicer.RegisterNewDataset, + request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, + response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, + ), + 'GetCurrentTimestamp': grpc.unary_unary_rpc_method_handler( + servicer.GetCurrentTimestamp, + request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, + response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, + ), + 'DeleteDataset': grpc.unary_unary_rpc_method_handler( + servicer.DeleteDataset, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, + ), + 'DeleteData': grpc.unary_unary_rpc_method_handler( + servicer.DeleteData, + request_deserializer=storage__pb2.DeleteDataRequest.FromString, + response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'modyn.storage.Storage', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Storage(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Get(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/Get', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetNL(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetNL', + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponseNoLabels.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetNewDataSince(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetNewDataSince', + storage__pb2.GetNewDataSinceRequest.SerializeToString, + storage__pb2.GetNewDataSinceResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDataInInterval(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataInInterval', + storage__pb2.GetDataInIntervalRequest.SerializeToString, + storage__pb2.GetDataInIntervalResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, 
wait_for_ready, timeout, metadata) + + @staticmethod + def GetDataPerWorker(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataPerWorker', + storage__pb2.GetDataPerWorkerRequest.SerializeToString, + storage__pb2.GetDataPerWorkerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDatasetSize(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetDatasetSize', + storage__pb2.GetDatasetSizeRequest.SerializeToString, + storage__pb2.GetDatasetSizeResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CheckAvailability(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/CheckAvailability', + storage__pb2.DatasetAvailableRequest.SerializeToString, + storage__pb2.DatasetAvailableResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def RegisterNewDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/RegisterNewDataset', + storage__pb2.RegisterNewDatasetRequest.SerializeToString, + storage__pb2.RegisterNewDatasetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetCurrentTimestamp(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetCurrentTimestamp', + storage__pb2.GetCurrentTimestampRequest.SerializeToString, + storage__pb2.GetCurrentTimestampResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteDataset(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/DeleteDataset', + storage__pb2.DatasetAvailableRequest.SerializeToString, + storage__pb2.DeleteDatasetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def DeleteData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, 
target, '/modyn.storage.Storage/DeleteData', + storage__pb2.DeleteDataRequest.SerializeToString, + storage__pb2.DeleteDataResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/storage/src/internal/grpc/storage_service_impl.cpp b/modyn/storage/src/internal/grpc/storage_service_impl.cpp index 4640f2593..105f48420 100644 --- a/modyn/storage/src/internal/grpc/storage_service_impl.cpp +++ b/modyn/storage/src/internal/grpc/storage_service_impl.cpp @@ -17,6 +17,11 @@ Status StorageServiceImpl::Get( // NOLINT readability-identifier-naming ServerWriter* writer) { return Get_Impl>(context, request, writer); } +Status StorageServiceImpl::GetNL( // NOLINT readability-identifier-naming + ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) { + return Get_Impl_NL>(context, request, writer); +} Status StorageServiceImpl::GetNewDataSince( // NOLINT readability-identifier-naming ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, diff --git a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py index 3f38d0a7e..e3790451c 100644 --- a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py +++ b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py @@ -30,7 +30,6 @@ def __init__(self, supervisor: Supervisor, modyn_config: dict) -> None: def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerContext) -> PipelineResponse: tid = threading.get_native_id() pid = os.getpid() - logger.info(f"[{pid}][{tid}]: Starting pipeline with request - {request}") start_replay_at: int | None = None @@ -44,7 +43,6 @@ def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerCo maximum_triggers = request.maximum_triggers pipeline_config = json.loads(request.pipeline_config.value) - msg = self._supervisor.start_pipeline( pipeline_config, self.modyn_config["supervisor"]["eval_directory"], diff --git a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py index 75333d46f..e4b641135 100644 --- a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py +++ b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py @@ -84,6 +84,7 @@ def register_tracking_info(self, tracking_dfs: dict[str, pd.DataFrame], dataset_ tracking_dfs: A dictionary of dataframes containing tracking information. dataset_end_time: Timestamp of the last sample in the dataset. 
""" + print(PipelineStage.STORE_TRAINED_MODEL.name) assert tracking_dfs.get(PipelineStage.HANDLE_SINGLE_TRIGGER.name) is not None assert tracking_dfs.get(PipelineStage.STORE_TRAINED_MODEL.name) is not None self.context = AfterPipelineEvalContext(tracking_dfs=tracking_dfs, dataset_end_time=dataset_end_time) diff --git a/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py b/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py index f3f89bc41..dec828a3d 100644 --- a/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py +++ b/modyn/tests/trainer_server/internal/grpc/test_trainer_server_grpc_servicer.py @@ -157,6 +157,7 @@ def get_start_training_request(checkpoint_path=""): pretrained_model_id=-1, lr_scheduler=JsonString(value=json.dumps({})), grad_scaler_configuration=JsonString(value=json.dumps({})), + generative=False ) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 48856aaa8..c87c41165 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -12,7 +12,15 @@ from typing import Any, cast import grpc -from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential +from tenacity import ( + Retrying, + after_log, + before_log, + retry, + stop_after_attempt, + wait_random_exponential, +) + from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -20,9 +28,13 @@ from modyn.storage.internal.grpc.generated.storage_pb2 import ( # pylint: disable=no-name-in-module GetRequest, GetResponse, + GetResponseNoLabels, ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub -from modyn.trainer_server.internal.dataset.key_sources import AbstractKeySource, SelectorKeySource +from modyn.trainer_server.internal.dataset.key_sources import ( + AbstractKeySource, + SelectorKeySource, +) from modyn.utils import ( BYTES_PARSER_FUNC_NAME, deserialize_function, @@ -53,7 +65,9 @@ def __init__( shuffle: bool, tokenizer: str | None, log_path: pathlib.Path | None, + generative: bool = False, ): + self._generative = generative self._pipeline_id = pipeline_id self._trigger_id = trigger_id self._training_id = training_id @@ -86,7 +100,9 @@ def __init__( self._pref_started: dict[int, bool] = {} self._thread_data_container: dict[int, dict[str, Any]] = {} self._partition_locks: dict[int, threading.Lock] = {} - self._partition_signals: dict[int, threading.Condition] = {} # Should use the lock out of partition_locks + self._partition_signals: dict[int, threading.Condition] = ( + {} + ) # Should use the lock out of partition_locks self._partition_valid_until: dict[int, int] = {} self._partition_valid: dict[int, bool] = {} self._next_partition_to_fetch = 0 @@ -96,13 +112,16 @@ def __init__( self._shuffled_partition_indices: list[int] = [] if log_path is None: - logger.warning("Did not provide log path for OnlineDataset - logging disabled.") + logger.warning( + "Did not provide log path for OnlineDataset - logging disabled." 
+ ) # tokenizer for NLP tasks self._tokenizer = None self._tokenizer_name = tokenizer if tokenizer is not None: - self._tokenizer = instantiate_class("modyn.models.tokenizers", tokenizer) + self._tokenizer = instantiate_class( + "modyn.models.tokenizers", tokenizer) logger.debug("Initialized OnlineDataset.") @@ -124,7 +143,9 @@ def _setup_composed_transform(self) -> None: self._transform = transforms.Compose(self._transform_list) def _init_transforms(self) -> None: - self._bytes_parser_function = deserialize_function(self._bytes_parser, BYTES_PARSER_FUNC_NAME) + self._bytes_parser_function = deserialize_function( + self._bytes_parser, BYTES_PARSER_FUNC_NAME + ) self._transform = self._bytes_parser_function self._setup_composed_transform() @@ -136,7 +157,9 @@ def _init_transforms(self) -> None: reraise=True, ) def _init_grpc(self, worker_id: int | None = None) -> None: # pragma: no cover - self._storage_channel = grpc.insecure_channel(self._storage_address, options=grpc_common_config()) + self._storage_channel = grpc.insecure_channel( + self._storage_address, options=grpc_common_config() + ) if grpc_connection_established(self._storage_channel): self._storagestub = StorageStub(self._storage_channel) return @@ -147,29 +170,38 @@ def _init_grpc(self, worker_id: int | None = None) -> None: # pragma: no cover def _silence_pil(self) -> None: # pragma: no cover pil_logger = logging.getLogger("PIL") - pil_logger.setLevel(logging.INFO) # by default, PIL on DEBUG spams the console + # by default, PIL on DEBUG spams the console + pil_logger.setLevel(logging.INFO) def _info(self, msg: str, worker_id: int | None) -> None: # pragma: no cover - logger.info(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") + logger.info( + f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}" + ) def _debug(self, msg: str, worker_id: int | None) -> None: # pragma: no cover - logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") + logger.debug( + f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}" + ) def _get_data_from_storage( self, selector_keys: list[int], worker_id: int | None = None ) -> Iterator[tuple[list[int], list[bytes], list[int], int]]: processed_keys: set[int] | list[int] = [] has_failed = False - for attempt in Retrying( - stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + reraise=True, ): with attempt: try: - req = GetRequest(dataset_id=self._dataset_id, keys=selector_keys) + new_keys: list[int] + new_samples: list[bytes] + response:GetResponse + req = GetRequest( + dataset_id=self._dataset_id, keys=selector_keys + ) stopw = Stopwatch() - - response: GetResponse stopw.start("ResponseTime", overwrite=True) for response in self._storagestub.Get(req): response_time = stopw.stop("ResponseTime") @@ -177,32 +209,115 @@ def _get_data_from_storage( if not has_failed: assert isinstance(processed_keys, list) processed_keys.extend(keys) - yield keys, list(response.samples), list(response.labels), response_time + yield keys, list(response.samples), list( + response.labels + ), response_time else: # If we have failed, we need to filter out yielded samples # Note that the returned order by storage is non-deterministic assert isinstance(processed_keys, set) - new_keys: list[int] = [key for key in keys if key not in processed_keys] - 
new_samples: list[bytes] = [ - sample for key, sample in zip(keys, response.samples) if key not in processed_keys + new_keys = [ + key for key in keys if key not in processed_keys + ] + new_samples = [ + sample + for key, sample in zip(keys, response.samples) + if key not in processed_keys ] new_labels: list[int] = [ - label for key, label in zip(keys, response.labels) if key not in processed_keys + label + for key, label in zip(keys, response.labels) + if key not in processed_keys ] processed_keys.update(keys) yield new_keys, new_samples, new_labels, response_time stopw.start("ResponseTime", overwrite=True) + except ( + grpc.RpcError + ) as e: # We catch and reraise to reconnect to rpc and do logging + has_failed = True + # Convert processed keys to set on first failure + processed_keys = ( + set(processed_keys) + if isinstance(processed_keys, list) + else processed_keys + ) + self._info( + "gRPC error occurred, processed_keys = " + + f"{processed_keys}\n{e.code()} - {e.details()}", + worker_id, + ) + self._info(f"Stringified exception: {str(e)}", worker_id) + self._info( + f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", + worker_id, + ) + self._init_grpc(worker_id=worker_id) + raise e + + def _get_data_from_storage_gen( + self, selector_keys: list[int], worker_id: int | None = None + ) -> Iterator[tuple[list[int], list[bytes], int]]: + processed_keys: set[int] | list[int] = [] + has_failed = False + for attempt in Retrying( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + reraise=True, + ): + with attempt: + try: + new_keys: list[int] + new_samples: list[bytes] + response: GetResponseNoLabels + req = GetRequest( + dataset_id=self._dataset_id, keys=selector_keys + ) + stopw = Stopwatch() + stopw.start("ResponseTime", overwrite=True) + for response in self._storagestub.GetNL(req): + response_time = stopw.stop("ResponseTime") + keys = list(response.keys) + if not has_failed: + assert isinstance(processed_keys, list) + processed_keys.extend(keys) + yield keys, list(response.samples), response_time + else: # If we have failed, we need to filter out yielded samples + # Note that the returned order by storage is non-deterministic + assert isinstance(processed_keys, set) + new_keys = [ + key for key in keys if key not in processed_keys + ] + new_samples = [ + sample + for key, sample in zip(keys, response.samples) + if key not in processed_keys + ] - except grpc.RpcError as e: # We catch and reraise to reconnect to rpc and do logging + processed_keys.update(keys) + yield new_keys, new_samples, response_time + + stopw.start("ResponseTime", overwrite=True) + except ( + grpc.RpcError + ) as e: # We catch and reraise to reconnect to rpc and do logging has_failed = True # Convert processed keys to set on first failure - processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys + processed_keys = ( + set(processed_keys) + if isinstance(processed_keys, list) + else processed_keys + ) self._info( - "gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}", + "gRPC error occurred, processed_keys = " + + f"{processed_keys}\n{e.code()} - {e.details()}", worker_id, ) self._info(f"Stringified exception: {str(e)}", worker_id) - self._info(f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", worker_id) + self._info( + f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", + worker_id, + ) 
self._init_grpc(worker_id=worker_id) raise e @@ -223,37 +338,77 @@ def _get_data( get_data_log = {} self._sw.start(f"GetKeysAndWeightsPart{partition_id}", overwrite=True) keys, weights = self._key_source.get_keys_and_weights( - worker_id, shuffled_partition_id if shuffled_partition_id is not None else partition_id + worker_id, + ( + shuffled_partition_id + if shuffled_partition_id is not None + else partition_id + ), + ) + get_data_log["get_keys_and_weights"] = self._sw.stop( + f"GetKeysAndWeightsPart{partition_id}" ) - get_data_log["get_keys_and_weights"] = self._sw.stop(f"GetKeysAndWeightsPart{partition_id}") get_data_log["num_items"] = len(keys) self._info("Getting data from storage", worker_id) self._sw.start(f"GetDataPart{partition_id}", overwrite=True) all_response_times = [] + key_weight_map = ( + {key: weights[idx] for idx, key in enumerate(keys)} + if weights is not None + else None + ) + if self._generative: + for data_tuple_gen in self._get_data_from_storage_gen(keys, worker_id=worker_id): + stor_keys, data, response_time = data_tuple_gen + + all_response_times.append(response_time) + num_items = len(stor_keys) + with ( + partition_locks[partition_id] + if partition_locks is not None + else contextlib.suppress() + ): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["weights"].extend( + [cast(float | None, key_weight_map[key]) + for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - key_weight_map = {key: weights[idx] for idx, key in enumerate(keys)} if weights is not None else None - - for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): - stor_keys, data, labels, response_time = data_tuple - - all_response_times.append(response_time) - num_items = len(stor_keys) - with partition_locks[partition_id] if partition_locks is not None else contextlib.suppress(): - data_container["data"].extend(data) - data_container["keys"].extend(stor_keys) - data_container["labels"].extend(labels) - data_container["weights"].extend( - [cast(float | None, key_weight_map[key]) for key in stor_keys] - if key_weight_map is not None - else [None for _ in range(len(stor_keys))] - ) - if partition_valid_until is not None: - partition_valid_until[partition_id] += num_items + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() + else: + for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): + stor_keys, data, labels, response_time = data_tuple + + all_response_times.append(response_time) + num_items = len(stor_keys) + with ( + partition_locks[partition_id] + if partition_locks is not None + else contextlib.suppress() + ): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["labels"].extend(labels) + data_container["weights"].extend( + [cast(float | None, key_weight_map[key]) + for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - if partition_signals is not None: - with partition_signals[partition_id]: - partition_signals[partition_id].notify_all() + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() get_data_log["get_data"] = 
self._sw.stop(f"GetDataPart{partition_id}") get_data_log["response_times"] = all_response_times @@ -281,6 +436,18 @@ def _get_transformed_data_tuple( return key, transformed_sample, label, weight return key, transformed_sample, label + def _get_transformed_data_tuple_gen( + self, key: int, sample: memoryview, weight: float | None + ) -> tuple | None: + assert self._uses_weights is not None + self._sw.start("transform", resume=True) + # mypy complains here because _transform has unknown type, which is ok + transformed_sample = self._transform(sample) # type: ignore + self._sw.stop("transform") + if self._uses_weights: + return key, transformed_sample, weight + return key, transformed_sample + def end_of_trigger_cleaning(self) -> None: self._key_source.end_of_trigger_cleaning() @@ -297,14 +464,22 @@ def _persist_log(self, worker_id: int) -> None: log_file = f"{self._log_path / str(worker_id)}.log" self._log["transform"] = self._sw.measurements.get("transform", 0) - self._log["wait_for_later_partitions"] = self._sw.measurements.get("wait_for_later_partitions", 0) - self._log["wait_for_initial_partition"] = self._sw.measurements.get("wait_for_initial_partition", 0) + self._log["wait_for_later_partitions"] = self._sw.measurements.get( + "wait_for_later_partitions", 0 + ) + self._log["wait_for_initial_partition"] = self._sw.measurements.get( + "wait_for_initial_partition", 0 + ) with open(log_file, "w", encoding="utf-8") as logfile: json.dump(self._log, logfile) def _clear_partition(self, partition_id: int) -> None: - with self._partition_locks[partition_id] if self._partition_locks is not None else contextlib.suppress(): + with ( + self._partition_locks[partition_id] + if self._partition_locks is not None + else contextlib.suppress() + ): self._partition_valid[partition_id] = False self._partition_valid_until[partition_id] = -1 del self._thread_data_container[partition_id] @@ -312,35 +487,51 @@ def _clear_partition(self, partition_id: int) -> None: if "PYTEST_CURRENT_TEST" not in os.environ: gc.collect() - def _prefetch_partition(self, worker_id: int, maybe_continue: bool = False) -> None: + def _prefetch_partition( + self, worker_id: int, maybe_continue: bool = False + ) -> None: assert self._start_prefetch_lock is not None with self._start_prefetch_lock: - if self._num_prefetched_partitions < 1 or self._next_partition_to_fetch >= self._num_partitions: + if ( + self._num_prefetched_partitions < 1 + or self._next_partition_to_fetch >= self._num_partitions + ): return # Prefetching disabled or nothing more to prefetch - if maybe_continue and self._launched_prefetches >= self._num_prefetched_partitions: + if ( + maybe_continue + and self._launched_prefetches >= self._num_prefetched_partitions + ): return # Two callbacks started to prefetch basically at the same time if maybe_continue: # Do this as early as possible to avoid running into the "problem" above frequently self._launched_prefetches += 1 - assert self._next_partition_to_fetch >= 0 assert ( self._next_partition_to_fetch not in self._data_threads ), f"Prefetching for partition {self._next_partition_to_fetch} has already been started" - - self._thread_data_container[self._next_partition_to_fetch] = { - "data": [], - "keys": [], - "labels": [], - "weights": [], - } + if not self._generative: + self._thread_data_container[self._next_partition_to_fetch] = { + "data": [], + "keys": [], + "labels": [], + "weights": [], + } + else: + self._thread_data_container[self._next_partition_to_fetch] = { + "data": [], + "keys": [], + "weights": [], + } 
self._partition_valid[self._next_partition_to_fetch] = False self._partition_valid_until[self._next_partition_to_fetch] = -1 - self._partition_locks[self._next_partition_to_fetch] = threading.Lock() - self._partition_signals[self._next_partition_to_fetch] = threading.Condition( - self._partition_locks[self._next_partition_to_fetch] + self._partition_locks[self._next_partition_to_fetch] = threading.Lock( + ) + self._partition_signals[self._next_partition_to_fetch] = ( + threading.Condition( + self._partition_locks[self._next_partition_to_fetch] + ) ) callback = None @@ -367,7 +558,9 @@ def callback_func() -> None: # We implement shuffling on a partition level by mapping everything to increasing indices but actually load # different partition data. shuffle_partition_id = ( - self._shuffled_partition_indices[self._next_partition_to_fetch] if self._shuffle else None + self._shuffled_partition_indices[self._next_partition_to_fetch] + if self._shuffle + else None ) self._data_threads[self._next_partition_to_fetch] = threading.Thread( target=self._get_data, @@ -393,13 +586,37 @@ def _fetch_partition_noprefetch( self, worker_id: int, partition_id: int ) -> Iterator[tuple[int, memoryview, int, float | None]]: assert self._num_prefetched_partitions < 1 - container: dict[str, Any] = {"data": [], "keys": [], "labels": [], "weights": []} - shuffle_partition_id = self._shuffled_partition_indices[partition_id] if self._shuffle else None - self._get_data(container, worker_id, partition_id, None, None, None, None, None, shuffle_partition_id) - assert "data" in container and "labels" in container and "keys" in container and "weights" in container - + container: dict[str, Any] = { + "data": [], + "keys": [], + "labels": [], + "weights": [], + } + shuffle_partition_id = ( + self._shuffled_partition_indices[partition_id] if self._shuffle else None + ) + self._get_data( + container, + worker_id, + partition_id, + None, + None, + None, + None, + None, + shuffle_partition_id, + ) + assert ( + "data" in container + and "labels" in container + and "keys" in container + and "weights" in container + ) if self._shuffle: - self._shuffle_partition(partition_id, worker_id, container=container) + if not self._generative: + self._shuffle_partition(partition_id, worker_id,container=container) + else: + self._shuffle_partition_gen(partition_id, worker_id,container=container) for idx in range(len(container["keys"])): yield ( @@ -409,8 +626,45 @@ def _fetch_partition_noprefetch( container["weights"][idx], ) + def _fetch_partition_noprefetch_generative( # based on the one above + self, worker_id: int, partition_id: int + ) -> Iterator[tuple[int, memoryview, float | None]]: + assert self._num_prefetched_partitions < 1 + container: dict[str, Any] = {"data": [], "keys": [], "weights": []} + shuffle_partition_id = ( + self._shuffled_partition_indices[partition_id] if self._shuffle else None + ) + self._get_data( + container, + worker_id, + partition_id, + None, + None, + None, + None, + None, + shuffle_partition_id, + ) + assert "data" in container and "keys" in container and "weights" in container + + if self._shuffle: + if not self._generative: + self._shuffle_partition(partition_id, worker_id) + else: + self._shuffle_partition_gen(partition_id, worker_id) + + for idx in range(len(container["keys"])): + yield ( + container["keys"][idx], + memoryview(container["data"][idx]), + container["weights"][idx], + ) + def _is_partition_fetched(self, partition_id: int) -> bool: - if partition_id not in self._partition_locks or partition_id 
not in self._partition_valid: + if ( + partition_id not in self._partition_locks + or partition_id not in self._partition_valid + ): return False with self._partition_locks[partition_id]: @@ -422,24 +676,43 @@ def _partition_max_index(self, partition_id: int) -> int: def _get_partition_data( self, last_idx: int, max_idx: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: - for idx in range(last_idx + 1, max_idx + 1): - yield ( - self._thread_data_container[partition_id]["keys"][idx], - memoryview(self._thread_data_container[partition_id]["data"][idx]), - self._thread_data_container[partition_id]["labels"][idx], - self._thread_data_container[partition_id]["weights"][idx], - ) + ) -> Iterator[tuple[int, memoryview, int, float | None]]| Iterator[tuple[int, memoryview, float | None]]: + if self._generative: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview( + self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["weights"][idx], + ) + else: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview( + self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["labels"][idx], + self._thread_data_container[partition_id]["weights"][idx], + ) def _wait_for_new_partition_data(self, partition_id: int) -> None: with self._partition_signals[partition_id]: - self._partition_signals[partition_id].wait(1) # In case we do not get woken up, we at most waste a second + self._partition_signals[partition_id].wait( + 1 + ) # In case we do not get woken up, we at most waste a second - def _shuffle_partition(self, partition_id: int, worker_id: int, container: dict | None = None) -> None: - assert container is not None or self._is_partition_fetched(partition_id) + def _shuffle_partition( + self, partition_id: int, worker_id: int, container: dict | None = None + ) -> None: + assert container is not None or self._is_partition_fetched( + partition_id) assert self._shuffle - container = container if container is not None else self._thread_data_container[partition_id] + container = ( + container + if container is not None + else self._thread_data_container[partition_id] + ) self._info(f"Shuffling partition {partition_id}", worker_id) @@ -459,9 +732,40 @@ def _shuffle_partition(self, partition_id: int, worker_id: int, container: dict self._info(f"Shuffled partition {partition_id}", worker_id) + def _shuffle_partition_gen( + self, partition_id: int, worker_id: int, container: dict | None = None + ) -> None: + assert container is not None or self._is_partition_fetched( + partition_id) + assert self._shuffle + + container = ( + container + if container is not None + else self._thread_data_container[partition_id] + ) + + self._info(f"Shuffling partition {partition_id}", worker_id) + + data_length = len(container["data"]) + indices = list(range(data_length)) + random.shuffle(indices) + + new_data = [container["data"][i] for i in indices] + new_keys = [container["keys"][i] for i in indices] + + new_weights = [container["weights"][i] for i in indices] + + container["data"] = new_data + container["keys"] = new_keys + + container["weights"] = new_weights + + self._info(f"Shuffled partition {partition_id}", worker_id) + def prefetched_partition_generator( self, worker_id: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: + ) -> 
Iterator[tuple[int, memoryview, int, float | None]]|Iterator[tuple[int, memoryview, float | None]]: last_idx = -1 if not self._shuffle: # If we do not shuffle, we can emit data as soon as it streamed over @@ -470,14 +774,17 @@ def prefetched_partition_generator( max_idx = self._partition_max_index(partition_id) if max_idx <= last_idx: # No new data self._wait_for_new_partition_data(partition_id) - yield from self._get_partition_data(last_idx, max_idx, partition_id) last_idx = max_idx else: while not self._is_partition_fetched(partition_id): self._wait_for_new_partition_data(partition_id) + if not self._generative: + + self._shuffle_partition(partition_id, worker_id) + else: - self._shuffle_partition(partition_id, worker_id) + self._shuffle_partition_gen(partition_id, worker_id) # Yield potential remaining data (when not shuffling) or all data (when shuffling) self._info(f"Joining thread for partition {partition_id}", worker_id) @@ -504,21 +811,28 @@ def start_prefetching(self, worker_id: int) -> None: for _ in range(self._parallel_prefetch_requests): self._prefetch_partition(worker_id, True) - def all_partition_generator(self, worker_id: int) -> Iterator[tuple[int, memoryview, int, float | None]]: + def all_partition_generator( + self, worker_id: int + ) -> Iterator[tuple[int, memoryview, int, float | None]]|Iterator[tuple[int, memoryview,float | None]]: self.start_prefetching(worker_id) for partition_id in range(self._num_partitions): self._persist_log(worker_id) - if self._num_prefetched_partitions > 0: if partition_id < self._num_partitions - 1: # As we consume one partition, prefetch exactly one more partition self._prefetch_partition(worker_id, False) yield from self.prefetched_partition_generator(worker_id, partition_id) + else: - yield from self._fetch_partition_noprefetch(worker_id, partition_id) + if not self._generative: + yield from self._fetch_partition_noprefetch(worker_id, partition_id) + else: + yield from self._fetch_partition_noprefetch_generative( + worker_id, partition_id + ) # pylint: disable=too-many-locals, too-many-branches, too-many-statements def __iter__(self) -> Generator: @@ -531,7 +845,9 @@ def __iter__(self) -> Generator: if self._first_call: self._first_call = False - self._debug("This is the first run of iter, making gRPC connections.", worker_id) + self._debug( + "This is the first run of iter, making gRPC connections.", worker_id + ) # We have to initialize transformations and gRPC connections here to do it per dataloader worker, # otherwise the transformations/gRPC connections cannot be pickled for the new processes. 
self._init_transforms() @@ -560,9 +876,13 @@ def __iter__(self) -> Generator: self._num_partitions = self._key_source.get_num_data_partitions() if self._shuffle: - self._shuffled_partition_indices = list(range(0, self._num_partitions)) + self._shuffled_partition_indices = list( + range(0, self._num_partitions)) random.shuffle(self._shuffled_partition_indices) - self._info(f"Shuffled partitions into random order: {self._shuffled_partition_indices}", worker_id) + self._info( + f"Shuffled partitions into random order: {self._shuffled_partition_indices}", + worker_id, + ) self._info( f"Total number of partitions will be {self._num_partitions}.\n" @@ -573,10 +893,22 @@ def __iter__(self) -> Generator: assert self._log_lock is not None with self._log_lock: self._log["num_partitions"] = self._num_partitions - self._num_prefetched_partitions = min(self._num_prefetched_partitions, self._num_partitions) - - for data_tuple in self.all_partition_generator(worker_id): - if (transformed_tuple := self._get_transformed_data_tuple(*data_tuple)) is not None: - yield transformed_tuple + self._num_prefetched_partitions = min( + self._num_prefetched_partitions, self._num_partitions + ) + if self._generative: + for data_tuple in self.all_partition_generator(worker_id): + if ( + transformed_tuple := self._get_transformed_data_tuple_gen( + *data_tuple + ) + ) is not None: + yield transformed_tuple + else: + for data_tuple in self.all_partition_generator(worker_id): + if ( + transformed_tuple := self._get_transformed_data_tuple(*data_tuple) + ) is not None: + yield transformed_tuple self._persist_log(worker_id) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py index c3a21053b..b3fedcdb5 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py @@ -1,56 +1,53 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: trainer_server.proto -# Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" - +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\xbb\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b 
\x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' -) -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "trainer_server_pb2", _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_JSONSTRING"]._serialized_start = 33 - _globals["_JSONSTRING"]._serialized_end = 60 - _globals["_PYTHONSTRING"]._serialized_start = 62 - _globals["_PYTHONSTRING"]._serialized_end = 91 - _globals["_DATA"]._serialized_start = 93 - _globals["_DATA"]._serialized_end = 144 - _globals["_TRAINERAVAILABLEREQUEST"]._serialized_start = 146 - _globals["_TRAINERAVAILABLEREQUEST"]._serialized_end = 171 - _globals["_TRAINERAVAILABLERESPONSE"]._serialized_start = 173 - _globals["_TRAINERAVAILABLERESPONSE"]._serialized_end = 218 - _globals["_CHECKPOINTINFO"]._serialized_start = 220 - _globals["_CHECKPOINTINFO"]._serialized_end = 290 - _globals["_STARTTRAININGREQUEST"]._serialized_start = 293 - _globals["_STARTTRAININGREQUEST"]._serialized_end = 1248 - _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1250 - _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1320 - _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1322 - _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1366 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1369 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1791 - _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1793 - _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1838 - _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1840 - _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 1904 - _globals["_GETLATESTMODELREQUEST"]._serialized_start = 1906 - _globals["_GETLATESTMODELREQUEST"]._serialized_end = 1950 - _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 1952 - _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2017 - _globals["_TRAINERSERVER"]._serialized_start = 2020 - _globals["_TRAINERSERVER"]._serialized_end = 2477 + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14trainer_server.proto\x12\x07trainer\"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t\"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t\"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\"\x19\n\x17TrainerAvailableRequest\"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 
\x01(\x08\"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t\"\xcf\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12\"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x12\x12\n\ngenerative\x18\x1c \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer\"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05\",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen\"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05\",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05\"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 
.trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse\"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse\"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse\"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse\"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'trainer_server_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _JSONSTRING._serialized_start=33 + _JSONSTRING._serialized_end=60 + _PYTHONSTRING._serialized_start=62 + _PYTHONSTRING._serialized_end=91 + _DATA._serialized_start=93 + _DATA._serialized_end=144 + _TRAINERAVAILABLEREQUEST._serialized_start=146 + _TRAINERAVAILABLEREQUEST._serialized_end=171 + _TRAINERAVAILABLERESPONSE._serialized_start=173 + _TRAINERAVAILABLERESPONSE._serialized_end=218 + _CHECKPOINTINFO._serialized_start=220 + _CHECKPOINTINFO._serialized_end=290 + _STARTTRAININGREQUEST._serialized_start=293 + _STARTTRAININGREQUEST._serialized_end=1268 + _STARTTRAININGRESPONSE._serialized_start=1270 + _STARTTRAININGRESPONSE._serialized_end=1340 + _TRAININGSTATUSREQUEST._serialized_start=1342 + _TRAININGSTATUSREQUEST._serialized_end=1386 + _TRAININGSTATUSRESPONSE._serialized_start=1389 + _TRAININGSTATUSRESPONSE._serialized_end=1811 + _STOREFINALMODELREQUEST._serialized_start=1813 + _STOREFINALMODELREQUEST._serialized_end=1858 + _STOREFINALMODELRESPONSE._serialized_start=1860 + _STOREFINALMODELRESPONSE._serialized_end=1924 + _GETLATESTMODELREQUEST._serialized_start=1926 + _GETLATESTMODELREQUEST._serialized_end=1970 + _GETLATESTMODELRESPONSE._serialized_start=1972 + _GETLATESTMODELRESPONSE._serialized_end=2037 + _TRAINERSERVER._serialized_start=2040 + _TRAINERSERVER._serialized_end=2497 # @@protoc_insertion_point(module_scope) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi index 96b20dde5..aeb9fda1a 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi @@ -56,9 +56,7 @@ class Data(google.protobuf.message.Message): dataset_id: builtins.str = ..., num_dataloaders: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["dataset_id", b"dataset_id", "num_dataloaders", b"num_dataloaders"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "num_dataloaders", b"num_dataloaders"]) -> None: ... global___Data = Data @@ -101,19 +99,13 @@ class CheckpointInfo(google.protobuf.message.Message): checkpoint_interval: builtins.int = ..., checkpoint_path: builtins.str = ..., ) -> None: ... - def ClearField( - self, - field_name: typing.Literal[ - "checkpoint_interval", b"checkpoint_interval", "checkpoint_path", b"checkpoint_path" - ], - ) -> None: ... + def ClearField(self, field_name: typing.Literal["checkpoint_interval", b"checkpoint_interval", "checkpoint_path", b"checkpoint_path"]) -> None: ... 
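For orientation in the regenerated stubs: the StartTrainingRequest section that follows gains a GENERATIVE_FIELD_NUMBER constant and a generative: builtins.bool field, matching the generative=False added to the test request earlier in this diff. A minimal, illustrative construction with placeholder values (not a real pipeline configuration):

import json

from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import (
    JsonString,
    StartTrainingRequest,
)

# Placeholder values for illustration; real requests are assembled from the pipeline config.
request = StartTrainingRequest(
    pipeline_id=1,
    trigger_id=0,
    device="cuda:0",
    batch_size=64,
    torch_criterion="CrossEntropyLoss",
    criterion_parameters=JsonString(value=json.dumps({})),
    generative=False,  # the new bool field; True selects the label-free data path
)

Because generative is a plain proto3 bool, omitting it defaults to False and keeps the previous, supervised behaviour.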
global___CheckpointInfo = CheckpointInfo @typing.final class StartTrainingRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - PIPELINE_ID_FIELD_NUMBER: builtins.int TRIGGER_ID_FIELD_NUMBER: builtins.int DEVICE_FIELD_NUMBER: builtins.int @@ -141,6 +133,7 @@ class StartTrainingRequest(google.protobuf.message.Message): ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int DROP_LAST_BATCH_FIELD_NUMBER: builtins.int + GENERATIVE_FIELD_NUMBER: builtins.int pipeline_id: builtins.int trigger_id: builtins.int device: builtins.str @@ -158,6 +151,7 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool record_loss_every: builtins.int drop_last_batch: builtins.bool + generative: builtins.bool @property def torch_optimizers_configuration(self) -> global___JsonString: ... @property @@ -208,105 +202,14 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool = ..., record_loss_every: builtins.int = ..., drop_last_batch: builtins.bool = ..., + generative: builtins.bool = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_seed", - b"_seed", - "_tokenizer", - b"_tokenizer", - "bytes_parser", - b"bytes_parser", - "checkpoint_info", - b"checkpoint_info", - "criterion_parameters", - b"criterion_parameters", - "data_info", - b"data_info", - "grad_scaler_configuration", - b"grad_scaler_configuration", - "label_transformer", - b"label_transformer", - "lr_scheduler", - b"lr_scheduler", - "seed", - b"seed", - "tokenizer", - b"tokenizer", - "torch_optimizers_configuration", - b"torch_optimizers_configuration", - ], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing.Literal[ - "_seed", - b"_seed", - "_tokenizer", - b"_tokenizer", - "batch_size", - b"batch_size", - "bytes_parser", - b"bytes_parser", - "checkpoint_info", - b"checkpoint_info", - "criterion_parameters", - b"criterion_parameters", - "data_info", - b"data_info", - "device", - b"device", - "drop_last_batch", - b"drop_last_batch", - "enable_accurate_gpu_measurements", - b"enable_accurate_gpu_measurements", - "epochs_per_trigger", - b"epochs_per_trigger", - "grad_scaler_configuration", - b"grad_scaler_configuration", - "label_transformer", - b"label_transformer", - "load_optimizer_state", - b"load_optimizer_state", - "lr_scheduler", - b"lr_scheduler", - "num_prefetched_partitions", - b"num_prefetched_partitions", - "num_samples_to_pass", - b"num_samples_to_pass", - "parallel_prefetch_requests", - b"parallel_prefetch_requests", - "pipeline_id", - b"pipeline_id", - "pretrained_model_id", - b"pretrained_model_id", - "record_loss_every", - b"record_loss_every", - "seed", - b"seed", - "shuffle", - b"shuffle", - "tokenizer", - b"tokenizer", - "torch_criterion", - b"torch_criterion", - "torch_optimizers_configuration", - b"torch_optimizers_configuration", - "transform_list", - b"transform_list", - "trigger_id", - b"trigger_id", - "use_pretrained_model", - b"use_pretrained_model", - ], - ) -> None: ... 
+ def HasField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "lr_scheduler", b"lr_scheduler", "seed", b"seed", "tokenizer", b"tokenizer", "torch_optimizers_configuration", b"torch_optimizers_configuration"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "drop_last_batch", b"drop_last_batch", "enable_accurate_gpu_measurements", b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", "generative", b"generative", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "num_prefetched_partitions", b"num_prefetched_partitions", "num_samples_to_pass", b"num_samples_to_pass", "parallel_prefetch_requests", b"parallel_prefetch_requests", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "record_loss_every", b"record_loss_every", "seed", b"seed", "shuffle", b"shuffle", "tokenizer", b"tokenizer", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ... @typing.overload def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_tokenizer", b"_tokenizer"] - ) -> typing.Literal["tokenizer"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_tokenizer", b"_tokenizer"]) -> typing.Literal["tokenizer"] | None: ... global___StartTrainingRequest = StartTrainingRequest @@ -324,9 +227,7 @@ class StartTrainingResponse(google.protobuf.message.Message): training_started: builtins.bool = ..., training_id: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["training_id", b"training_id", "training_started", b"training_started"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["training_id", b"training_id", "training_started", b"training_started"]) -> None: ... global___StartTrainingResponse = StartTrainingResponse @@ -387,90 +288,18 @@ class TrainingStatusResponse(google.protobuf.message.Message): downsampling_batches_seen: builtins.int | None = ..., downsampling_samples_seen: builtins.int | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing.Literal[ - "_batches_seen", - b"_batches_seen", - "_downsampling_batches_seen", - b"_downsampling_batches_seen", - "_downsampling_samples_seen", - b"_downsampling_samples_seen", - "_exception", - b"_exception", - "_samples_seen", - b"_samples_seen", - "batches_seen", - b"batches_seen", - "downsampling_batches_seen", - b"downsampling_batches_seen", - "downsampling_samples_seen", - b"downsampling_samples_seen", - "exception", - b"exception", - "log", - b"log", - "samples_seen", - b"samples_seen", - ], - ) -> builtins.bool: ... 
- def ClearField( - self, - field_name: typing.Literal[ - "_batches_seen", - b"_batches_seen", - "_downsampling_batches_seen", - b"_downsampling_batches_seen", - "_downsampling_samples_seen", - b"_downsampling_samples_seen", - "_exception", - b"_exception", - "_samples_seen", - b"_samples_seen", - "batches_seen", - b"batches_seen", - "blocked", - b"blocked", - "downsampling_batches_seen", - b"downsampling_batches_seen", - "downsampling_samples_seen", - b"downsampling_samples_seen", - "exception", - b"exception", - "is_running", - b"is_running", - "is_training", - b"is_training", - "log", - b"log", - "samples_seen", - b"samples_seen", - "state_available", - b"state_available", - "valid", - b"valid", - ], - ) -> None: ... + def HasField(self, field_name: typing.Literal["_batches_seen", b"_batches_seen", "_downsampling_batches_seen", b"_downsampling_batches_seen", "_downsampling_samples_seen", b"_downsampling_samples_seen", "_exception", b"_exception", "_samples_seen", b"_samples_seen", "batches_seen", b"batches_seen", "downsampling_batches_seen", b"downsampling_batches_seen", "downsampling_samples_seen", b"downsampling_samples_seen", "exception", b"exception", "log", b"log", "samples_seen", b"samples_seen"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_batches_seen", b"_batches_seen", "_downsampling_batches_seen", b"_downsampling_batches_seen", "_downsampling_samples_seen", b"_downsampling_samples_seen", "_exception", b"_exception", "_samples_seen", b"_samples_seen", "batches_seen", b"batches_seen", "blocked", b"blocked", "downsampling_batches_seen", b"downsampling_batches_seen", "downsampling_samples_seen", b"downsampling_samples_seen", "exception", b"exception", "is_running", b"is_running", "is_training", b"is_training", "log", b"log", "samples_seen", b"samples_seen", "state_available", b"state_available", "valid", b"valid"]) -> None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_batches_seen", b"_batches_seen"] - ) -> typing.Literal["batches_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_batches_seen", b"_batches_seen"]) -> typing.Literal["batches_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_downsampling_batches_seen", b"_downsampling_batches_seen"] - ) -> typing.Literal["downsampling_batches_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_downsampling_batches_seen", b"_downsampling_batches_seen"]) -> typing.Literal["downsampling_batches_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_downsampling_samples_seen", b"_downsampling_samples_seen"] - ) -> typing.Literal["downsampling_samples_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_downsampling_samples_seen", b"_downsampling_samples_seen"]) -> typing.Literal["downsampling_samples_seen"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_exception", b"_exception"] - ) -> typing.Literal["exception"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_exception", b"_exception"]) -> typing.Literal["exception"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing.Literal["_samples_seen", b"_samples_seen"] - ) -> typing.Literal["samples_seen"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_samples_seen", b"_samples_seen"]) -> typing.Literal["samples_seen"] | None: ... 
global___TrainingStatusResponse = TrainingStatusResponse @@ -503,9 +332,7 @@ class StoreFinalModelResponse(google.protobuf.message.Message): valid_state: builtins.bool = ..., model_id: builtins.int = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["model_id", b"model_id", "valid_state", b"valid_state"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["model_id", b"model_id", "valid_state", b"valid_state"]) -> None: ... global___StoreFinalModelResponse = StoreFinalModelResponse @@ -538,8 +365,6 @@ class GetLatestModelResponse(google.protobuf.message.Message): valid_state: builtins.bool = ..., model_path: builtins.str = ..., ) -> None: ... - def ClearField( - self, field_name: typing.Literal["model_path", b"model_path", "valid_state", b"valid_state"] - ) -> None: ... + def ClearField(self, field_name: typing.Literal["model_path", b"model_path", "valid_state", b"valid_state"]) -> None: ... global___GetLatestModelResponse = GetLatestModelResponse diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py index a1978b37a..b195f51fc 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py @@ -1,36 +1,8 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" - -import warnings - import grpc -import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 - -GRPC_GENERATED_VERSION = "1.63.0" -GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - warnings.warn( - f"The grpc package installed is at version {GRPC_VERSION}," - + f" but the generated code in trainer_server_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." - + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." - + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, - ) +import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 class TrainerServerStub(object): """Missing associated documentation comment in .proto file.""" @@ -42,35 +14,30 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.trainer_available = channel.unary_unary( - "/trainer.TrainerServer/trainer_available", - request_serializer=trainer__server__pb2.TrainerAvailableRequest.SerializeToString, - response_deserializer=trainer__server__pb2.TrainerAvailableResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/trainer_available', + request_serializer=trainer__server__pb2.TrainerAvailableRequest.SerializeToString, + response_deserializer=trainer__server__pb2.TrainerAvailableResponse.FromString, + ) self.start_training = channel.unary_unary( - "/trainer.TrainerServer/start_training", - request_serializer=trainer__server__pb2.StartTrainingRequest.SerializeToString, - response_deserializer=trainer__server__pb2.StartTrainingResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/start_training', + request_serializer=trainer__server__pb2.StartTrainingRequest.SerializeToString, + response_deserializer=trainer__server__pb2.StartTrainingResponse.FromString, + ) self.get_training_status = channel.unary_unary( - "/trainer.TrainerServer/get_training_status", - request_serializer=trainer__server__pb2.TrainingStatusRequest.SerializeToString, - response_deserializer=trainer__server__pb2.TrainingStatusResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/get_training_status', + request_serializer=trainer__server__pb2.TrainingStatusRequest.SerializeToString, + response_deserializer=trainer__server__pb2.TrainingStatusResponse.FromString, + ) self.store_final_model = channel.unary_unary( - "/trainer.TrainerServer/store_final_model", - request_serializer=trainer__server__pb2.StoreFinalModelRequest.SerializeToString, - response_deserializer=trainer__server__pb2.StoreFinalModelResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/store_final_model', + request_serializer=trainer__server__pb2.StoreFinalModelRequest.SerializeToString, + response_deserializer=trainer__server__pb2.StoreFinalModelResponse.FromString, + ) self.get_latest_model = channel.unary_unary( - "/trainer.TrainerServer/get_latest_model", - request_serializer=trainer__server__pb2.GetLatestModelRequest.SerializeToString, - response_deserializer=trainer__server__pb2.GetLatestModelResponse.FromString, - _registered_method=True, - ) + '/trainer.TrainerServer/get_latest_model', + request_serializer=trainer__server__pb2.GetLatestModelRequest.SerializeToString, + response_deserializer=trainer__server__pb2.GetLatestModelResponse.FromString, + ) class TrainerServerServicer(object): @@ -79,216 +46,152 @@ class TrainerServerServicer(object): def trainer_available(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def start_training(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def get_training_status(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise 
NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def store_final_model(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def get_latest_model(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_TrainerServerServicer_to_server(servicer, server): rpc_method_handlers = { - "trainer_available": grpc.unary_unary_rpc_method_handler( - servicer.trainer_available, - request_deserializer=trainer__server__pb2.TrainerAvailableRequest.FromString, - response_serializer=trainer__server__pb2.TrainerAvailableResponse.SerializeToString, - ), - "start_training": grpc.unary_unary_rpc_method_handler( - servicer.start_training, - request_deserializer=trainer__server__pb2.StartTrainingRequest.FromString, - response_serializer=trainer__server__pb2.StartTrainingResponse.SerializeToString, - ), - "get_training_status": grpc.unary_unary_rpc_method_handler( - servicer.get_training_status, - request_deserializer=trainer__server__pb2.TrainingStatusRequest.FromString, - response_serializer=trainer__server__pb2.TrainingStatusResponse.SerializeToString, - ), - "store_final_model": grpc.unary_unary_rpc_method_handler( - servicer.store_final_model, - request_deserializer=trainer__server__pb2.StoreFinalModelRequest.FromString, - response_serializer=trainer__server__pb2.StoreFinalModelResponse.SerializeToString, - ), - "get_latest_model": grpc.unary_unary_rpc_method_handler( - servicer.get_latest_model, - request_deserializer=trainer__server__pb2.GetLatestModelRequest.FromString, - response_serializer=trainer__server__pb2.GetLatestModelResponse.SerializeToString, - ), + 'trainer_available': grpc.unary_unary_rpc_method_handler( + servicer.trainer_available, + request_deserializer=trainer__server__pb2.TrainerAvailableRequest.FromString, + response_serializer=trainer__server__pb2.TrainerAvailableResponse.SerializeToString, + ), + 'start_training': grpc.unary_unary_rpc_method_handler( + servicer.start_training, + request_deserializer=trainer__server__pb2.StartTrainingRequest.FromString, + response_serializer=trainer__server__pb2.StartTrainingResponse.SerializeToString, + ), + 'get_training_status': grpc.unary_unary_rpc_method_handler( + servicer.get_training_status, + request_deserializer=trainer__server__pb2.TrainingStatusRequest.FromString, + response_serializer=trainer__server__pb2.TrainingStatusResponse.SerializeToString, + ), + 'store_final_model': grpc.unary_unary_rpc_method_handler( + servicer.store_final_model, + request_deserializer=trainer__server__pb2.StoreFinalModelRequest.FromString, + response_serializer=trainer__server__pb2.StoreFinalModelResponse.SerializeToString, + ), + 'get_latest_model': grpc.unary_unary_rpc_method_handler( + servicer.get_latest_model, + request_deserializer=trainer__server__pb2.GetLatestModelRequest.FromString, + response_serializer=trainer__server__pb2.GetLatestModelResponse.SerializeToString, + ), } - 
generic_handler = grpc.method_handlers_generic_handler("trainer.TrainerServer", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + 'trainer.TrainerServer', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class TrainerServer(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def trainer_available( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def trainer_available(request, target, - "/trainer.TrainerServer/trainer_available", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/trainer_available', trainer__server__pb2.TrainerAvailableRequest.SerializeToString, trainer__server__pb2.TrainerAvailableResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def start_training( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def start_training(request, target, - "/trainer.TrainerServer/start_training", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/start_training', trainer__server__pb2.StartTrainingRequest.SerializeToString, trainer__server__pb2.StartTrainingResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_training_status( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def get_training_status(request, target, - "/trainer.TrainerServer/get_training_status", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/get_training_status', trainer__server__pb2.TrainingStatusRequest.SerializeToString, trainer__server__pb2.TrainingStatusResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def store_final_model( - request, - target, - options=(), - 
channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def store_final_model(request, target, - "/trainer.TrainerServer/store_final_model", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/store_final_model', trainer__server__pb2.StoreFinalModelRequest.SerializeToString, trainer__server__pb2.StoreFinalModelResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_latest_model( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, + def get_latest_model(request, target, - "/trainer.TrainerServer/get_latest_model", + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/trainer.TrainerServer/get_latest_model', trainer__server__pb2.GetLatestModelRequest.SerializeToString, trainer__server__pb2.GetLatestModelResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True, - ) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py index 47acce439..3cc6ea2d7 100644 --- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py +++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py @@ -85,7 +85,8 @@ def __init__( self.pipeline_id = training_info.pipeline_id self.training_id = training_info.training_id self.trigger_id = training_info.trigger_id - + self._info("Initializing Pytorch Trainer") + self.generative = training_info.generative self.selector_stub = self.connect_to_selector(training_info.selector_address) if training_info.seed is not None: @@ -217,7 +218,6 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches self._info("Handled OnBegin Callbacks.") self._log["epochs"] = [] - training_loss: list[float] = [] if self.num_samples_to_pass == 0: epoch_num_generator: Iterable[int] = range(self.epochs_per_trigger) @@ -272,7 +272,7 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches passed_batches += 1 with GPUMeasurement(self._measure_gpu_ops, "PreprocessBatch", self._device, stopw, resume=True): sample_ids, target, data = self.preprocess_batch(batch, stopw) - + # For generative tasks, no separate label is prepared here; the next-token target is derived from data in the loss computation if retrieve_weights_from_dataloader: # model output is a torch.FloatTensor but weights is a torch.DoubleTensor.
# We need to cast to do the dot product @@ -297,15 +298,32 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches self._assert_data_size(self._batch_size, data, sample_ids, target) with GPUMeasurement(self._measure_gpu_ops, "Forward", self._device, stopw, resume=True): - output = self._model.model(data, sample_ids) + if self.generative: + # Generative task: the model only receives the input data + output = self._model.model(data) + + else: + # Non-generative task: pass both the data and the sample IDs + output = self._model.model(data, sample_ids=sample_ids) + with GPUMeasurement(self._measure_gpu_ops, "Loss", self._device, stopw, resume=True): - if weighted_optimization: - # weighted gradient descent - assert weights is not None - loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum()) + if self.generative: + # Shift logits and labels for next-token prediction + output = output[..., :-1, :].contiguous() + target = data[..., 1:].contiguous() + # Flatten to (batch * seq_len, vocab_size) logits against (batch * seq_len,) targets + loss = self._criterion( + output.view(-1, output.size(-1)), + target.view(-1) + ) else: - loss = self._criterion(output, target) + if weighted_optimization: + # Weighted gradient descent + assert weights is not None + loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum()) + else: + loss = self._criterion(output, target) stopw.start("OnBatchBeforeUpdate", resume=True) for _, callback in self._callbacks.items(): @@ -514,19 +532,21 @@ def preprocess_batch( sample_ids = sample_ids.tolist() elif isinstance(sample_ids, tuple): sample_ids = list(sample_ids) - assert isinstance(sample_ids, list), "Cannot parse result from DataLoader" stopw.stop("PreprocSampleIDs") - - stopw.start("LabelTransform", resume=True) - if self._label_transformer_function is not None: - target = self._label_transformer_function(batch[2]) + if self.generative: + # Generative tasks carry no separate label; the target is derived from the data later + target = None else: - target = batch[2] - stopw.stop("LabelTransform") + stopw.start("LabelTransform", resume=True) + if self._label_transformer_function is not None: + target = self._label_transformer_function(batch[2]) + else: + target = batch[2] + stopw.stop("LabelTransform") - with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True): - target = target.to(self._device) + with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True): + target = target.to(self._device) with GPUMeasurement(self._measure_gpu_ops, "MoveDataToGPU", self._device, stopw, resume=True): data: torch.Tensor | dict @@ -573,9 +593,7 @@ def downsample_batch( big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor() embeddings = self.get_embeddings_if_recorded() self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings) - self.end_embedding_recorder_if_needed() - # TODO(#218) Persist information on the sample IDs/weights when downsampling is performed selected_indexes, weights = self._downsampler.select_points() selected_data, selected_target = get_tensors_subset(selected_indexes, data, target, sample_ids) @@ -898,32 +916,36 @@ def _iterate_dataloader_and_compute_scores( Args: dataloader: torch.dataloader to get the data previous_batch_number: The batch number returned from the last call to this method. Useful when this - function is called several times to keep track of previous invocations (ex label by label dataloader).
We - need to have a total to correctly update the queue and show the progress in the supervisor counter. - previous_number_of_samples: number of samples processed before calling this function. See above for the use. + function is called several times to keep track of previous invocations (e.g., label-by-label dataloader). + previous_number_of_samples: Number of samples processed before calling this function. See above for usage. Returns: Updated number of batches and samples """ number_of_samples = previous_number_of_samples batch_number = previous_batch_number + for batch in dataloader: self.update_queue("DOWNSAMPLING", batch_number, number_of_samples, training_active=False) batch_number += 1 sample_ids, target, data = self.preprocess_batch(batch) + # Generative tasks return no label from preprocess_batch; use an empty tensor as a placeholder + if self.generative and target is None: + target = torch.Tensor() number_of_samples += len(sample_ids) - no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode() context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr with context_manager: with torch.autocast(self._device_type, enabled=self._amp): - # compute the scores and accumulate them - model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor() + model_output = ( + self._model.model(data) if self._downsampler.forward_required else torch.Tensor() + ) embeddings = self.get_embeddings_if_recorded() + # Inform the downsampler so it can compute and accumulate the scores self._downsampler.inform_samples(sample_ids, data, model_output, target, embeddings) - return batch_number, number_of_samples + # ---------------------------------------------------- Logging --------------------------------------------------- # def _info(self, msg: str) -> None: diff --git a/modyn/trainer_server/internal/utils/training_info.py b/modyn/trainer_server/internal/utils/training_info.py index 07a246b35..ac9df472a 100644 --- a/modyn/trainer_server/internal/utils/training_info.py +++ b/modyn/trainer_server/internal/utils/training_info.py @@ -57,7 +57,7 @@ def __init__( self.shuffle = request.shuffle self.enable_accurate_gpu_measurements = request.enable_accurate_gpu_measurements - + self.generative = request.generative assert ( self.pretrained_model_path or not self.use_pretrained_model ), "Inconsistent pretrained model configuration"
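For reference, the next-token objective introduced in the generative branch of pytorch_trainer.py can be reproduced in isolation with the sketch below; tensor names and shapes are illustrative assumptions, not Modyn APIs. It also makes explicit why the flattened targets, not the flattened logits, must be the second argument to the criterion, which is the form used in the loss hunk above.

    import torch
    from torch import nn

    criterion = nn.CrossEntropyLoss()

    def next_token_loss(logits: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seq_len, vocab_size) model output; data: (batch, seq_len) token ids.
        # Predict token t+1 from tokens up to t: drop the last logit and the first token.
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_targets = data[..., 1:].contiguous()
        return criterion(
            shifted_logits.view(-1, shifted_logits.size(-1)),
            shifted_targets.view(-1),
        )

    # Example with random data: a batch of 2 sequences of length 8 over a 100-token vocabulary.
    logits = torch.randn(2, 8, 100)
    tokens = torch.randint(0, 100, (2, 8))
    print(next_token_loss(logits, tokens))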