Skip to content

Commit

Permalink
update pipeline registration in test_online_dataset
Browse files (browse the repository at this point in the history)
  • Loading branch information
jenny011 committed Oct 24, 2023
1 parent a6359af commit 1a3816e
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions integrationtests/online_dataset/test_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import time
from typing import Iterable, Tuple

import enlighten
import grpc
import modyn.storage.internal.grpc.generated.storage_pb2 as storage_pb2
import torch
import yaml
from modyn.selector.internal.grpc.generated.selector_pb2 import DataInformRequest, JsonString, RegisterPipelineRequest
from modyn.selector.internal.grpc.generated.selector_pb2 import DataInformRequest, JsonString
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub
from modyn.storage.internal.grpc.generated.storage_pb2 import (
DatasetAvailableRequest,
Expand All @@ -24,6 +25,7 @@
)
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub
from modyn.trainer_server.internal.dataset.data_utils import prepare_dataloaders
from modyn.supervisor.internal.grpc_handler import GRPCHandler
from modyn.utils import grpc_connection_established
from PIL import Image
from torchvision import transforms
Expand All @@ -50,6 +52,42 @@ def get_modyn_config() -> dict:

return config

def get_grpc_handler() -> GRPCHandler:
    """Create a GRPCHandler wired to a fresh enlighten manager and status bar.

    The handler is constructed from the config returned by
    ``get_modyn_config()`` together with the enlighten manager and the status
    bar created here.

    Returns:
        The newly constructed ``GRPCHandler``.
    """
    manager = enlighten.get_manager()

    # Options for the status bar; ``demo`` fills the ``{demo}`` placeholder
    # of ``status_format``.
    bar_options = {
        "status_format": "Modyn{fill}Current Task: {demo}{fill}{elapsed}",
        "color": "bold_underline_bright_white_on_lightslategray",
        "justify": enlighten.Justify.CENTER,
        "demo": "Initializing",
        "autorefresh": True,
        "min_delta": 0.5,
    }
    bar = manager.status_bar(**bar_options)

    # The config is loaded only after the status bar exists, matching the
    # original evaluation order.
    modyn_config = get_modyn_config()
    return GRPCHandler(modyn_config, manager, bar)

def get_minimal_pipeline_config(num_workers: int, strategy_config: dict) -> dict:
    """Assemble a minimal pipeline configuration dictionary for the test.

    Args:
        num_workers: Value stored under ``training.dataloader_workers``.
        strategy_config: Dictionary stored (by reference, not copied) under
            ``training.selection_strategy``.

    Returns:
        A dict with the top-level keys ``pipeline``, ``model``, ``training``,
        ``data`` and ``trigger``, in that order.
    """
    sgd_optimizer = {
        "name": "default1",
        "algorithm": "SGD",
        "source": "PyTorch",
        "param_groups": [{"module": "model"}],
    }

    training_section = {
        "gpus": 1,
        "device": "cpu",
        "dataloader_workers": num_workers,
        "use_previous_model": True,
        "initial_model": "random",
        "initial_pass": {"activated": False},
        "learning_rate": 0.1,
        "batch_size": 42,
        "optimizers": [sgd_optimizer],
        "optimization_criterion": {"name": "CrossEntropyLoss"},
        "checkpointing": {"activated": False},
        "selection_strategy": strategy_config,
    }

    data_section = {
        "dataset_id": "test",
        # Source of an identity parser, kept as a string exactly as written.
        "bytes_parser_function": "def bytes_parser_function(x):\n\treturn x",
    }

    trigger_section = {
        "id": "DataAmountTrigger",
        "trigger_config": {"data_points_for_trigger": 1},
    }

    # Insert the sections in the same order as the original literal so that
    # iteration/serialization order is unchanged.
    pipeline_config = {}
    pipeline_config["pipeline"] = {"name": "Test"}
    pipeline_config["model"] = {"id": "ResNet18"}
    pipeline_config["training"] = training_section
    pipeline_config["data"] = data_section
    pipeline_config["trigger"] = trigger_section
    return pipeline_config


def connect_to_selector_servicer() -> grpc.Channel:
selector_address = get_selector_address()
Expand Down Expand Up @@ -201,11 +239,9 @@ def prepare_selector(num_dataworkers: int, keys: list[int]) -> Tuple[int, int]:
"config": {"limit": -1, "reset_after_trigger": True},
}

pipeline_id = selector.register_pipeline(
RegisterPipelineRequest(
num_workers=max(num_dataworkers, 1), selection_strategy=JsonString(value=json.dumps(strategy_config))
)
).pipeline_id
grpc_handler = get_grpc_handler()
pipeline_config = get_minimal_pipeline_config(max(num_dataworkers, 1), strategy_config)
pipeline_id = grpc_handler.register_pipeline(pipeline_config)

trigger_id = selector.inform_data_and_trigger(
DataInformRequest(
Expand Down

0 comments on commit 1a3816e

Please sign in to comment.