diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml index 3b8120a06..d650a8930 100644 --- a/.github/workflows/workflow.yaml +++ b/.github/workflows/workflow.yaml @@ -120,7 +120,7 @@ jobs: clangFormatVersion: 16 tidy: - timeout-minutes: 30 + timeout-minutes: 60 runs-on: ubuntu-latest env: CLANG_TIDY: clang-tidy-15 @@ -252,8 +252,8 @@ jobs: branch-coverage: ${{steps.run_main_test_with_coverage.outputs.BRANCH_COVERAGE}} steps: - uses: actions/checkout@v2 - #with: TODO(MaxiBoether): add after merge. - # ref: main + with: + ref: main - name: Install clang 17 uses: KyleMayes/install-llvm-action@v1 @@ -317,7 +317,7 @@ jobs: # Checks whether the base container works correctly. dockerized-unittests: - timeout-minutes: 60 + timeout-minutes: 180 runs-on: ubuntu-latest needs: - flake8 @@ -342,7 +342,7 @@ jobs: integrationtests-debug: - timeout-minutes: 90 + timeout-minutes: 180 runs-on: ubuntu-latest needs: - flake8 @@ -361,7 +361,7 @@ jobs: run: bash scripts/run_integrationtests.sh Debug integrationtests-asan: - timeout-minutes: 90 + timeout-minutes: 180 runs-on: ubuntu-latest needs: - flake8 @@ -380,7 +380,7 @@ jobs: run: bash scripts/run_integrationtests.sh Asan integrationtests-tsan: - timeout-minutes: 90 + timeout-minutes: 180 runs-on: ubuntu-latest needs: - flake8 @@ -399,7 +399,7 @@ jobs: run: bash scripts/run_integrationtests.sh Tsan integrationtests-release: - timeout-minutes: 90 + timeout-minutes: 180 runs-on: ubuntu-latest needs: - flake8 diff --git a/.gitignore b/.gitignore index a284a4724..b27b3cde5 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,12 @@ report.html # Pytest creates files that have the name of the local desktop included, so we need to wildcard here .coverage.* +# storage c++ specific +!modyn/storage/lib +!modyn/storage/lib/googletest + +# Unity build files +cmake-build-debug # File that stores whether Modyn has been configured + backup environment .modyn_configured environment.yml.original @@ -67,4 +73,4 @@ plots/ # Unity build files cmake-build-debug/ clang-tidy-build/ -libbuild/ \ No newline at end of file +libbuild/ diff --git a/.pylintrc b/.pylintrc index ff3b34209..574c3eb6b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -53,6 +53,9 @@ ignore-paths=^modyn/trainer_server/internal/grpc/generated/.*$, ^modyn/metadata_processor/internal/grpc/generated/.*$, ^modyn/metadata_database/internal/grpc/generated.*$, ^modyn/storage/internal/grpc/generated/.*$, + ^modyn/build/.*$, + ^modyn/cmake-build-debug/.*$, + ^modyn/libbuild/.*$, ^modyn/model_storage/internal/grpc/generated/.*$, ^modyn/evaluator/internal/grpc/generated/.*$, ^modyn/models/dlrm/cuda_ext/.*$, diff --git a/CMakeLists.txt b/CMakeLists.txt index 44a3563c3..5376e5f7d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,8 @@ set(CMAKE_EXE_LINKER_FLAGS_TSAN "${CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO} -fsani ##### PUBLIC OPTIONS ##### option(MODYN_BUILD_PLAYGROUND "Set ON to build playground" ON) option(MODYN_BUILD_TESTS "Set ON to build tests" ON) -option(MODYN_BUILD_STORAGE "Set ON to build storage components" OFF) # TODO(MaxiBoether): use this flag when merging into storage PR +option(MODYN_BUILD_STORAGE "Set ON to build storage components" OFF) +option(MODYN_TRY_LOCAL_GRPC "Set ON to try using local gRPC installation instead of building from source" ON) option(MODYN_TEST_COVERAGE "Set ON to add test coverage" OFF) #### INTERNAL OPTIONS #### diff --git a/benchmark/criteo_1TB/execute_pipelines.sh b/benchmark/criteo_1TB/execute_pipelines.sh index c26fbbdc2..4da7bbd45 100644 --- 
a/benchmark/criteo_1TB/execute_pipelines.sh +++ b/benchmark/criteo_1TB/execute_pipelines.sh @@ -10,4 +10,4 @@ for filename in $SCRIPT_DIR/pipelines/*.yml; do EVAL_DIR="$BASEDIR/$BASE" mkdir -p $EVAL_DIR modyn-supervisor --start-replay-at 0 --evaluation-matrix $filename $MODYN_CONFIG_PATH $EVAL_DIR -done \ No newline at end of file +done diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index 0b5e73970..c70558ff9 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -1,19 +1,8 @@ include(FetchContent) -# TODO(MaxiBoether): when merging storage, only downloads the new packages if MODYN_BUILD_STORAGE is enabled - # Configure path to modules (for find_package) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PROJECT_SOURCE_DIR}/cmake/modules/") -################### spdlog #################### -message(STATUS "Making spdlog available.") -FetchContent_Declare( - spdlog - GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG v1.12.0 -) -FetchContent_MakeAvailable(spdlog) - ################### fmt #################### message(STATUS "Making fmt available.") FetchContent_Declare( @@ -23,6 +12,16 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(fmt) +################### spdlog #################### +message(STATUS "Making spdlog available.") +set(SPDLOG_FMT_EXTERNAL ON) # Otherwise, we run into linking errors since the fmt version used by spdlog does not match. +FetchContent_Declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v1.12.0 +) +FetchContent_MakeAvailable(spdlog) + ################### argparse #################### message(STATUS "Making argparse available.") FetchContent_Declare( @@ -41,3 +40,23 @@ FetchContent_Declare( GIT_TAG v1.14.0 ) FetchContent_MakeAvailable(googletest) + +if (${MODYN_BUILD_STORAGE}) + message(STATUS "Including storage dependencies.") + include(${MODYN_CMAKE_DIR}/storage_dependencies.cmake) +endif () + +################### yaml-cpp #################### +# Technically, yaml-cpp is currently only required by storage +# But we have a test util function requiring this. 
+ +message(STATUS "Making yaml-cpp available.") + +FetchContent_Declare( + yaml-cpp + GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git + GIT_TAG yaml-cpp-0.7.0 +) +FetchContent_MakeAvailable(yaml-cpp) + +target_compile_options(yaml-cpp INTERFACE -Wno-shadow -Wno-pedantic -Wno-deprecated-declarations) diff --git a/cmake/storage_dependencies.cmake b/cmake/storage_dependencies.cmake new file mode 100644 index 000000000..0f9b209ea --- /dev/null +++ b/cmake/storage_dependencies.cmake @@ -0,0 +1,129 @@ +include(FetchContent) +list(APPEND CMAKE_PREFIX_PATH /opt/homebrew/opt/libpq) # for macOS builds + +# Configure path to modules (for find_package) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PROJECT_SOURCE_DIR}/cmake/modules/") + +# Use original download path +message(STATUS "CMAKE_BINARY_DIR = ${CMAKE_BINARY_DIR}.") +message(STATUS "FETCHCONTENT_BASE_DIR = ${FETCHCONTENT_BASE_DIR}.") + +################### libpq++ #################### +find_package(PostgreSQL REQUIRED) # This needs to be installed on the system - cannot do a lightweight CMake install + +################### rapidcsv #################### +message(STATUS "Making rapidcsv available.") + +FetchContent_Declare( + rapidcsv + GIT_REPOSITORY https://github.com/d99kris/rapidcsv.git + GIT_TAG v8.80 +) +FetchContent_MakeAvailable(rapidcsv) + +################### soci #################### +message(STATUS "Making soci available.") + +FetchContent_Declare( + soci + GIT_REPOSITORY https://github.com/SOCI/soci.git + GIT_TAG v4.0.3 +) +set(SOCI_TESTS OFF CACHE BOOL "soci configuration") +set(SOCI_CXX11 ON CACHE BOOL "soci configuration") +set(SOCI_STATIC ON CACHE BOOL "soci configuration") +set(SOCI_SHARED OFF CACHE BOOL "soci configuration") +set(SOCI_EMPTY OFF CACHE BOOL "soci configuration") +set(SOCI_HAVE_BOOST OFF CACHE BOOL "configuration" FORCE) + +FetchContent_GetProperties(soci) +if(NOT soci_POPULATED) + FetchContent_Populate(soci) + add_subdirectory(${soci_SOURCE_DIR} _deps) +endif() + +# Function to help us fix compiler warnings for all soci targets +function(get_all_targets src_dir var) + set(targets) + get_all_targets_recursive(targets ${src_dir}) + set(${var} ${targets} PARENT_SCOPE) +endfunction() + +macro(get_all_targets_recursive targets dir) + get_property(subdirectories DIRECTORY ${dir} PROPERTY SUBDIRECTORIES) + foreach(subdir ${subdirectories}) + get_all_targets_recursive(${targets} ${subdir}) + endforeach() + + get_property(current_targets DIRECTORY ${dir} PROPERTY BUILDSYSTEM_TARGETS) + list(APPEND ${targets} ${current_targets}) +endmacro() + +get_all_targets(${soci_SOURCE_DIR} all_soci_targets) +foreach(_soci_target IN LISTS all_soci_targets) + target_compile_options(${_soci_target} INTERFACE -Wno-shadow -Wno-zero-as-null-pointer-constant -Wno-pedantic -Wno-undef) +endforeach() + + +################### gRPC #################### +set(MODYN_USES_LOCAL_GRPC false) +if(MODYN_TRY_LOCAL_GRPC) + set(protobuf_MODULE_COMPATIBLE true) + find_package(Protobuf CONFIG) + find_package(gRPC CONFIG) + + if (gRPC_FOUND) + message(STATUS "Found gRPC version ${gRPC_VERSION} locally (gRPC_FOUND = ${gRPC_FOUND})!") + if (NOT TARGET gRPC::grpc_cpp_plugin) + message(STATUS "gRPC::grpc_cpp_plugin is not a target, despite finding CMake. 
Building from source.") + set(MODYN_TRY_LOCAL_GRPC OFF) + else() + if (Protobuf_FOUND) + message(STATUS "Found protobuf!") + include_directories(${PROTOBUF_INCLUDE_DIRS}) + set(MODYN_USES_LOCAL_GRPC true) + if (NOT TARGET grpc_cpp_plugin) + message(STATUS "Since grpc_cpp_plugin was not defined as a target, we define it manually.") + add_executable(grpc_cpp_plugin ALIAS gRPC::grpc_cpp_plugin) + endif() + else() + message(FATAL "Did not find Protobuf, please run cmake in a clean build directory with -DMODYN_TRY_LOCAL_GRPC=Off or install protobuf on your system.") + endif() + endif() + else() + message(STATUS "Did not find gRPC locally, building from source.") + endif() +endif() + +if((NOT MODYN_TRY_LOCAL_GRPC) OR (NOT gRPC_FOUND)) + message(STATUS "Making gRPC available (this may take a while).") + set(gRPC_PROTOBUF_PROVIDER "module" CACHE BOOL "" FORCE) + set(ABSL_ENABLE_INSTALL ON) # https://github.com/protocolbuffers/protobuf/issues/12185 + FetchContent_Declare( + gRPC + GIT_REPOSITORY https://github.com/grpc/grpc + GIT_TAG v1.59.2 # When updating this, make sure to also update the modynbase dockerfile + GIT_SHALLOW TRUE + ) + set(gRPC_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(gRPC_BUILD_CSHARP_EXT OFF CACHE BOOL "" FORCE) + set(ABSL_BUILD_TESTING OFF CACHE BOOL "" FORCE) + + set(FETCHCONTENT_QUIET OFF) + FetchContent_MakeAvailable(gRPC) + set(FETCHCONTENT_QUIET ON) +endif() + +file(DOWNLOAD +https://raw.githubusercontent.com/protocolbuffers/protobuf/v23.1/cmake/protobuf-generate.cmake +${CMAKE_CURRENT_BINARY_DIR}/protobuf-generate.cmake) +include(${CMAKE_CURRENT_BINARY_DIR}/protobuf-generate.cmake) + +if(NOT COMMAND protobuf_generate) + message(FATAL_ERROR "protobuf_generate not available. Potentially there is an error with your local CMake installation. If set, try using -DMODYN_TRY_LOCAL_GRPC=Off.") +else() + message(STATUS "Found protobuf_generate") +endif() + +message(STATUS "Processed gRPC.") + diff --git a/docker/Base/Dockerfile b/docker/Base/Dockerfile index 66cb0da1a..8023e69a5 100644 --- a/docker/Base/Dockerfile +++ b/docker/Base/Dockerfile @@ -1,11 +1,7 @@ FROM modyndependencies -ARG MODYN_BUILDTYPE=Release -ENV MODYN_BUILDTYPE=${MODYN_BUILDTYPE} - # Copy source code into container -ADD . /src -RUN echo "Used buildtype is ${MODYN_BUILDTYPE}" >> /src/.modyn_buildtype +COPY . 
/src RUN mamba run -n modyn pip install -e /src WORKDIR /src diff --git a/docker/Dependencies/Dockerfile b/docker/Dependencies/Dockerfile index 60fa7c07c..452105cce 100644 --- a/docker/Dependencies/Dockerfile +++ b/docker/Dependencies/Dockerfile @@ -7,6 +7,7 @@ ENV PYTHONUNBUFFERED=1 RUN apt-get update -yq \ && apt-get upgrade -yq \ && apt-get install --no-install-recommends -qy \ + autoconf \ build-essential \ gcc \ g++ \ @@ -18,17 +19,41 @@ RUN apt-get update -yq \ htop \ procps \ libjpeg-dev \ + libpq-dev \ gdb \ libdw-dev \ libelf-dev \ + libtool \ + pkg-config \ cmake \ - && rm -rf /var/lib/apt/lists/* + ca-certificates \ + libpq-dev \ + libsqlite3-dev \ + software-properties-common \ + curl \ + unzip \ + && rm -rf /var/lib/apt/lists/* \ + && gcc --version && g++ --version && cmake --version # Creates a non-root user with an explicit UID and adds permission to access the /app folder # For more info, please refer to https://aka.ms/vscode-docker-python-configure-containers RUN adduser -u 5678 --disabled-password --gecos "" appuser ENV PATH="${PATH}:/home/appuser/.local/bin" +RUN mkdir /src +ARG MODYN_BUILDTYPE=Release +ENV MODYN_BUILDTYPE=$MODYN_BUILDTYPE +ARG MODYN_DEP_BUILDTYPE=Release +ENV MODYN_DEP_BUILDTYPE=$MODYN_DEP_BUILDTYPE +RUN echo "Used buildtype is ${MODYN_BUILDTYPE}" >> /src/.modyn_buildtype +RUN echo "Used dependency buildtype is ${MODYN_DEP_BUILDTYPE}" >> /src/.modyn_dep_buildtype + +# Install gRPC systemwide. When updating the version, make sure to also update the storage_dependencies.cmake file +RUN git clone --recurse-submodules -b v1.59.2 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && \ + cmake -DgRPC_PROTOBUF_PROVIDER=module -DABSL_ENABLE_INSTALL=On -DgRPC_BUILD_CSHARP_EXT=Off -DABSL_BUILD_TESTING=Off -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=${MODYN_DEP_BUILDTYPE} ../.. && \ + make -j8 && make install && cd ../../ + # Install mamba ENV CONDA_DIR /opt/mamba ENV MAMBA_DIR /opt/mamba @@ -39,4 +64,5 @@ RUN mamba update -n base -c defaults mamba && mamba update --all && mamba init b # Install dependencies COPY ./environment.yml /tmp/environment.yml -RUN mamba env create -f /tmp/environment.yml \ No newline at end of file +RUN mamba env create -f /tmp/environment.yml + diff --git a/docker/Storage/Dockerfile b/docker/Storage/Dockerfile index 836dae0d8..de3cf5e3c 100644 --- a/docker/Storage/Dockerfile +++ b/docker/Storage/Dockerfile @@ -1,6 +1,28 @@ -FROM modynbase:latest +FROM modyndependencies:latest + +COPY ./CMakeLists.txt /src +COPY ./cmake /src/cmake +COPY ./modyn/CMakeLists.txt /src/modyn/CMakeLists.txt +COPY ./modyn/storage /src/modyn/storage +COPY ./modyn/common/CMakeLists.txt /src/modyn/common/CMakeLists.txt +COPY ./modyn/common/cpp /src/modyn/common/cpp +COPY ./modyn/common/example_extension /src/modyn/common/example_extension +COPY ./modyn/common/trigger_sample /src/modyn/common/trigger_sample +COPY ./modyn/protos/storage.proto /src/modyn/protos/storage.proto + +WORKDIR /src +RUN chown -R appuser /src +USER appuser + +RUN mkdir build \ + && cd build \ + && cmake .. -DCMAKE_BUILD_TYPE=${MODYN_BUILDTYPE} -DMODYN_BUILD_TESTS=Off -DMODYN_BUILD_PLAYGROUND=Off -DMODYN_BUILD_STORAGE=On \ + && make -j8 modyn-storage + +# These files are copied after building the storage to avoid rebuilding if the config changes +COPY ./modyn/config /src/modyn/config +COPY ./conf /src/conf -RUN chmod a+x /src/modyn/storage/modyn-storage # During debugging, this entry point will be overridden. 
For more information, please refer to https://aka.ms/vscode-docker-python-debug -CMD mamba run -n modyn --no-capture-output ./modyn/storage/modyn-storage ./modyn/config/examples/modyn_config.yaml \ No newline at end of file +CMD ./build/modyn/storage/modyn-storage ./modyn/config/examples/modyn_config.yaml \ No newline at end of file diff --git a/docs/TECHNICAL.md b/docs/TECHNICAL.md index 188512b6d..992ef31d6 100644 --- a/docs/TECHNICAL.md +++ b/docs/TECHNICAL.md @@ -43,6 +43,20 @@ In case you want to build extensions or components on your own, you need to crea By default, we only build the extensions to avoid downloading the huge gRPC library. In case you want to build the storage C++ component, enable `-DMODYN_BUILD_STORAGE=On` when running CMake. +Furthermore, by default, we enable the `-DMODYN_TRY_LOCAL_GRPC` flag. +This flag checks whether gRPC is available locally on your system and uses this installation for rapid development, instead of rebuilding gRPC from source everytime like in CI. +In order to install gRPC on your system, you can either use your system's package manager or run the following instructions: + +``` +git clone --recurse-submodules -b v1.59.2 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && \ + cmake -DgRPC_PROTOBUF_PROVIDER=module -DABSL_ENABLE_INSTALL=On -DgRPC_BUILD_CSHARP_EXT=Off -DABSL_BUILD_TESTING=Off -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=${MODYN_DEP_BUILDTYPE} ../.. && \ + make -j8 && make install && cd ../../ +``` + +Please adjust the version as required. +If you run into problems with the system gRPC installation, set `-DMODYN_TRY_LOCAL_GRPC=Off`. + ### Docker-Compose Setup We use docker-compose to manage the system setup. The `docker-compose.yml` file describes our setup and includes comments explaining it. diff --git a/integrationtests/online_dataset/test_online_dataset.py b/integrationtests/online_dataset/test_online_dataset.py index c7cb87b2b..bddc82c3a 100644 --- a/integrationtests/online_dataset/test_online_dataset.py +++ b/integrationtests/online_dataset/test_online_dataset.py @@ -25,6 +25,7 @@ from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub from modyn.trainer_server.internal.dataset.data_utils import prepare_dataloaders from modyn.utils import grpc_connection_established +from modyn.utils.utils import flatten from PIL import Image from torchvision import transforms @@ -234,20 +235,18 @@ def get_new_data_since(timestamp: int) -> Iterable[GetNewDataSinceResponse]: def get_data_keys() -> list[int]: - response = None keys = [] for i in range(60): responses = list(get_new_data_since(0)) - assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" - if len(responses) == 1: - response = responses[0] - keys = list(response.keys) + keys = [] + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) if len(keys) == 10: break time.sleep(1) - assert response is not None, "Did not get any response from Storage" - assert len(keys) == 10, f"Not all images were returned. Images returned: {response.keys}" + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 10, f"Not all images were returned. 
Images returned: {keys}" return keys @@ -378,8 +377,8 @@ def main() -> None: try: test_dataset() finally: - cleanup_dataset_dir() cleanup_storage_database() + cleanup_dataset_dir() if __name__ == "__main__": diff --git a/integrationtests/run.sh b/integrationtests/run.sh index 53e3e4e56..2f8c363f5 100755 --- a/integrationtests/run.sh +++ b/integrationtests/run.sh @@ -1,7 +1,7 @@ #!/bin/bash set -e # stops execution on non zero exit code -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) echo "Integration tests are located in $SCRIPT_DIR" echo "Running as user $USER" @@ -12,10 +12,11 @@ python $SCRIPT_DIR/test_ftp_connections.py echo "Running storage integration tests" python $SCRIPT_DIR/storage/integrationtest_storage.py python $SCRIPT_DIR/storage/integrationtest_storage_csv.py +python $SCRIPT_DIR/storage/integrationtest_storage_binary.py echo "Running selector integration tests" python $SCRIPT_DIR/selector/integrationtest_selector.py echo "Running online datasets integration tests" python $SCRIPT_DIR/online_dataset/test_online_dataset.py echo "Running model storage integration tests" python $SCRIPT_DIR/model_storage/integrationtest_model_storage.py -echo "Successfuly ran all integration tests." \ No newline at end of file +echo "Successfuly ran all integration tests." diff --git a/integrationtests/storage/integrationtest_storage.py b/integrationtests/storage/integrationtest_storage.py index 86693bbbc..4d8639935 100644 --- a/integrationtests/storage/integrationtest_storage.py +++ b/integrationtests/storage/integrationtest_storage.py @@ -1,5 +1,6 @@ import io import json +import math import os import pathlib import random @@ -12,6 +13,7 @@ import yaml from modyn.storage.internal.grpc.generated.storage_pb2 import ( DatasetAvailableRequest, + DeleteDataRequest, GetDataInIntervalRequest, GetDataInIntervalResponse, GetDataPerWorkerRequest, @@ -25,6 +27,7 @@ ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub from modyn.utils import grpc_connection_established +from modyn.utils.utils import flatten from PIL import Image SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) @@ -164,7 +167,7 @@ def cleanup_storage_database() -> None: def add_image_to_dataset(image: Image, name: str) -> None: image.save(DATASET_PATH / name) - IMAGE_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000))) + IMAGE_UPDATED_TIME_STAMPS.append(int(math.floor(os.path.getmtime(DATASET_PATH / name)))) def create_random_image() -> Image: @@ -233,6 +236,7 @@ def check_data(keys: list[str], expected_images: list[bytes]) -> None: keys=keys, ) + i = -1 for i, response in enumerate(storage.Get(request)): if len(response.samples) == 0: assert False, f"Could not get image with key {keys[i]}." @@ -245,55 +249,91 @@ def check_data(keys: list[str], expected_images: list[bytes]) -> None: assert i == len(keys) - 1, f"Could not get all images. Images missing: keys: {keys} i: {i}" +def check_delete_data(keys_to_delete: list[int]) -> None: + storage_channel = connect_to_storage() + + storage = StorageStub(storage_channel) + + request = DeleteDataRequest( + dataset_id="test_dataset", + keys=keys_to_delete, + ) + + responses = storage.DeleteData(request) + + assert responses.success, "Could not delete data." + + def test_storage() -> None: check_get_current_timestamp() # Check if the storage service is available. 
create_dataset_dir() - add_images_to_dataset(0, 10, FIRST_ADDED_IMAGES) # Add images to the dataset. register_new_dataset() check_dataset_availability() # Check if the dataset is available. + check_dataset_size(0) # Check if the dataset is empty. + check_dataset_size_invalid() + add_images_to_dataset(0, 10, FIRST_ADDED_IMAGES) # Add images to the dataset. + response = None for i in range(20): + keys = [] + labels = [] responses = list(get_new_data_since(0)) - assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" - if len(responses) == 1: - response = responses[0] - if len(response.keys) == 10: + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + labels = flatten([list(response.keys) for response in responses]) + if len(keys) == 10: + assert (label in [f"{i}" for i in range(0, 10)] for label in labels) break time.sleep(1) - assert response is not None, "Did not get any response from Storage" - assert len(response.keys) == 10, f"Not all images were returned. Images returned: {response.keys}" + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 10, f"Not all images were returned." - check_data(response.keys, FIRST_ADDED_IMAGES) + first_image_keys = keys + + check_data(keys, FIRST_ADDED_IMAGES) check_dataset_size(10) + # Otherwise, if the test runs too quick, the timestamps of the new data equals the timestamps of the old data, and then we have a problem + print("Sleeping for 2 seconds before adding more images to the dataset...") + time.sleep(2) + print("Continuing test.") + add_images_to_dataset(10, 20, SECOND_ADDED_IMAGES) # Add more images to the dataset. for i in range(60): + keys = [] + labels = [] responses = list(get_new_data_since(IMAGE_UPDATED_TIME_STAMPS[9] + 1)) - assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" - if len(responses) == 1: - response = responses[0] - if len(response.keys) == 10: + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + labels = flatten([list(response.keys) for response in responses]) + if len(keys) == 10: + assert (label in [f"{i}" for i in range(10, 20)] for label in labels) break time.sleep(1) - assert response is not None, "Did not get any response from Storage" - assert len(response.keys) == 10, f"Not all images were returned. Images returned: {response.keys}" + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 10, f"Not all images were returned. Images returned = {keys}" - check_data(response.keys, SECOND_ADDED_IMAGES) + check_data(keys, SECOND_ADDED_IMAGES) check_dataset_size(20) responses = list(get_data_in_interval(0, IMAGE_UPDATED_TIME_STAMPS[9])) - assert len(responses) == 1, f"Received batched/no response, shouldn't happen: {responses}" - response = responses[0] - check_data(response.keys, FIRST_ADDED_IMAGES) + assert len(responses) > 0, f"Received no response, shouldn't happen: {responses}" + keys = flatten([list(response.keys) for response in responses]) + + check_data(keys, FIRST_ADDED_IMAGES) check_data_per_worker() + check_delete_data(first_image_keys) + + check_dataset_size(10) + check_get_current_timestamp() # Check if the storage service is still available. 
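The hunks above replace the old single-response assumption with a batched-response pattern: `get_new_data_since` may now stream several `GetNewDataSinceResponse` messages (bounded by `sample_batch_size`), so the test flattens the keys across all responses before checking counts. A minimal, self-contained sketch of that polling pattern follows; the local `flatten` helper only mirrors what `modyn.utils.utils.flatten` is used for here, and `poll_new_keys`, its parameters, and the retry count are illustrative names rather than part of this PR.

```python
import time
from typing import Callable, Iterable, List


def flatten(nested: Iterable[Iterable[int]]) -> List[int]:
    # Stand-in for modyn.utils.utils.flatten: concatenate the key lists of all batched responses.
    return [item for sublist in nested for item in sublist]


def poll_new_keys(get_new_data_since: Callable, timestamp: int, expected: int, retries: int = 60) -> List[int]:
    """Poll storage until `expected` keys have arrived, tolerating a batched response stream."""
    keys: List[int] = []
    responses: List = []
    for _ in range(retries):
        responses = list(get_new_data_since(timestamp))
        if responses:
            # Each response carries only part of the keys, so merge them before comparing counts.
            keys = flatten([list(response.keys) for response in responses])
            if len(keys) == expected:
                break
        time.sleep(1)
    assert responses, "Did not get any response from Storage"
    assert len(keys) == expected, f"Expected {expected} keys, got {len(keys)}"
    return keys
```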
@@ -301,8 +341,8 @@ def main() -> None: try: test_storage() finally: - cleanup_dataset_dir() cleanup_storage_database() + cleanup_dataset_dir() if __name__ == "__main__": diff --git a/integrationtests/storage/integrationtest_storage_binary.py b/integrationtests/storage/integrationtest_storage_binary.py new file mode 100644 index 000000000..58f6ffb96 --- /dev/null +++ b/integrationtests/storage/integrationtest_storage_binary.py @@ -0,0 +1,185 @@ +############ +# storage integration tests adapted to binary input format. +# Unchanged functions are imported from the original test +# Instead of images, we have binary files. The binary files with random content of size 10 bytes. + +import json +import math +import os +import random +import time +from typing import Tuple + +# unchanged functions are imported from the original test file +from integrationtests.storage.integrationtest_storage import ( + DATASET_PATH, + check_dataset_availability, + check_get_current_timestamp, + cleanup_dataset_dir, + cleanup_storage_database, + connect_to_storage, + create_dataset_dir, + get_data_in_interval, + get_new_data_since, +) +from modyn.storage.internal.grpc.generated.storage_pb2 import GetRequest, RegisterNewDatasetRequest +from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub +from modyn.utils.utils import flatten + +# Because we have no mapping of file to key (happens in the storage service), we have to keep +# track of the samples we added to the dataset ourselves and compare them to the samples we get +# from the storage service. +FIRST_ADDED_BINARY = [] +SECOND_ADDED_BINARY = [] +BINARY_UPDATED_TIME_STAMPS = [] + + +def register_new_dataset() -> None: + storage_channel = connect_to_storage() + + storage = StorageStub(storage_channel) + + request = RegisterNewDatasetRequest( + base_path=str(DATASET_PATH), + dataset_id="test_dataset", + description="Test dataset for integration tests of binary wrapper.", + file_wrapper_config=json.dumps( + { + "file_extension": ".bin", + "label_size": 4, + "record_size": 10, + } + ), + file_wrapper_type="BinaryFileWrapper", + filesystem_wrapper_type="LocalFilesystemWrapper", + version="0.1.0", + ) + + response = storage.RegisterNewDataset(request) + + assert response.success, "Could not register new dataset." 
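The `file_wrapper_config` registered above fixes every record of a `.bin` file at 10 bytes, the first 4 of which hold the label; the helpers that follow in this new test build such files and keep only the payload bytes for later comparison. As a small illustration of that layout, the sketch below splits one file into (label, payload) pairs; interpreting the label as a little-endian signed integer is an assumption of this example, while the integration test itself only compares raw payload bytes.

```python
import struct
from typing import List, Tuple

RECORD_SIZE = 10  # bytes per record, as in the file_wrapper_config above
LABEL_SIZE = 4    # leading bytes of each record that hold the label


def split_records(binary_data: bytes) -> List[Tuple[int, bytes]]:
    """Split a .bin file into (label, payload) pairs according to the fixed-size record layout."""
    assert len(binary_data) % RECORD_SIZE == 0, "File length must be a multiple of the record size."
    records = []
    for offset in range(0, len(binary_data), RECORD_SIZE):
        record = binary_data[offset:offset + RECORD_SIZE]
        # Assumed little-endian label encoding; only the remaining payload bytes are checked by the test.
        label = struct.unpack("<i", record[:LABEL_SIZE])[0]
        records.append((label, record[LABEL_SIZE:]))
    return records
```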
+ + +def add_file_to_dataset(binary_data: bytes, name: str) -> None: + with open(DATASET_PATH / name, "wb") as f: + f.write(binary_data) + BINARY_UPDATED_TIME_STAMPS.append(int(math.floor(os.path.getmtime(DATASET_PATH / name)))) + + +def create_random_binary_file() -> Tuple[bytes, list[bytes]]: + binary_data = b"" + samples = [] + for i in range(250): + sample_binary_data = random.randbytes(10) + binary_data += sample_binary_data + samples.append(sample_binary_data[4:]) + + return binary_data, samples + + +def add_files_to_dataset( + start_number: int, + end_number: int, + samples: list[bytes], +) -> list[bytes]: + create_dataset_dir() + + for i in range(start_number, end_number): + binary_file, file_samples = create_random_binary_file() + add_file_to_dataset(binary_file, f"binary_{i}.bin") + samples.extend(file_samples) + + return samples + + +def check_data(keys: list[str], expected_samples: list[bytes]) -> None: + storage_channel = connect_to_storage() + + storage = StorageStub(storage_channel) + + request = GetRequest( + dataset_id="test_dataset", + keys=keys, + ) + samples_counter = 0 + for _, response in enumerate(storage.Get(request)): + if len(response.samples) == 0: + assert False, f"Could not get sample with key {keys[samples_counter]}." + for sample in response.samples: + if sample is None: + assert False, f"Could not get sample with key {keys[samples_counter]}." + if sample not in expected_samples: + raise ValueError( + f"Sample {sample} with key {keys[samples_counter]} is not present in the " + f"expected samples {expected_samples}. " + ) + samples_counter += 1 + + assert samples_counter == len( + keys + ), f"Could not get all samples. Samples missing: keys: {sorted(keys)} i: {samples_counter}" + + +def test_storage() -> None: + check_get_current_timestamp() # Check if the storage service is available. + create_dataset_dir() + register_new_dataset() + check_dataset_availability() # Check if the dataset is available. + + add_files_to_dataset(0, 10, FIRST_ADDED_BINARY) # Add samples to the dataset. + + response = None + for i in range(500): + responses = list(get_new_data_since(0)) + keys = [] + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + if len(keys) == 2500: # 10 files, each one with 250 samples + break + time.sleep(1) + + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 2500, f"Not all samples were returned. Samples returned: {keys}" + + check_data(keys, FIRST_ADDED_BINARY) + + # Otherwise, if the test runs too quick, the timestamps of the new data equals the timestamps of the old data, and then we have a problem + print("Sleeping for 2 seconds before adding more binary files to the dataset...") + time.sleep(2) + print("Continuing test.") + + add_files_to_dataset(10, 20, SECOND_ADDED_BINARY) # Add more samples to the dataset. + + for i in range(500): + responses = list(get_new_data_since(BINARY_UPDATED_TIME_STAMPS[9] + 1)) + keys = [] + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + if len(keys) == 2500: + break + time.sleep(1) + + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 2500, f"Not all samples were returned. 
Samples returned: {keys}" + + check_data(keys, SECOND_ADDED_BINARY) + + responses = list(get_data_in_interval(0, BINARY_UPDATED_TIME_STAMPS[9])) + assert len(responses) > 0, f"Received no response, shouldn't happen: {responses}" + keys = flatten([list(response.keys) for response in responses]) + + check_data(keys, FIRST_ADDED_BINARY) + + check_get_current_timestamp() # Check if the storage service is still available. + + +def main() -> None: + try: + test_storage() + finally: + cleanup_storage_database() + cleanup_dataset_dir() + + +if __name__ == "__main__": + main() diff --git a/integrationtests/storage/integrationtest_storage_csv.py b/integrationtests/storage/integrationtest_storage_csv.py index 0cdf6679f..0850d2244 100644 --- a/integrationtests/storage/integrationtest_storage_csv.py +++ b/integrationtests/storage/integrationtest_storage_csv.py @@ -6,6 +6,7 @@ # where index is a random number, file is the fileindex and the label (last column) is a global counter import json +import math import os import random import time @@ -25,6 +26,7 @@ ) from modyn.storage.internal.grpc.generated.storage_pb2 import GetRequest, RegisterNewDatasetRequest from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub +from modyn.utils.utils import flatten # Because we have no mapping of file to key (happens in the storage service), we have to keep # track of the samples we added to the dataset ourselves and compare them to the samples we get @@ -57,7 +59,7 @@ def register_new_dataset() -> None: def add_file_to_dataset(csv_file_content: str, name: str) -> None: with open(DATASET_PATH / name, "w") as f: f.write(csv_file_content) - CSV_UPDATED_TIME_STAMPS.append(int(round(os.path.getmtime(DATASET_PATH / name) * 1000))) + CSV_UPDATED_TIME_STAMPS.append(int(math.floor(os.path.getmtime(DATASET_PATH / name)))) def create_random_csv_row(file: int, counter: int) -> str: @@ -118,46 +120,52 @@ def check_data(keys: list[str], expected_samples: list[bytes]) -> None: def test_storage() -> None: check_get_current_timestamp() # Check if the storage service is available. create_dataset_dir() - add_files_to_dataset(0, 10, [], FIRST_ADDED_CSVS) # Add samples to the dataset. register_new_dataset() check_dataset_availability() # Check if the dataset is available. + add_files_to_dataset(0, 10, [], FIRST_ADDED_CSVS) # Add samples to the dataset. + response = None for i in range(500): responses = list(get_new_data_since(0)) - assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" - if len(responses) == 1: - response = responses[0] - if len(response.keys) == 250: # 10 files, each one with 250 samples + keys = [] + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + if len(keys) == 250: # 10 files, each one with 25 samples break time.sleep(1) - assert response is not None, "Did not get any response from Storage" - assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}" + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 250, f"Not all samples were returned. 
Samples returned: {keys}" - check_data(response.keys, FIRST_ADDED_CSVS) + check_data(keys, FIRST_ADDED_CSVS) + + # Otherwise, if the test runs too quick, the timestamps of the new data equals the timestamps of the old data, and then we have a problem + print("Sleeping for 2 seconds before adding more csvs to the dataset...") + time.sleep(2) + print("Continuing test.") add_files_to_dataset(10, 20, [], SECOND_ADDED_CSVS) # Add more samples to the dataset. for i in range(500): responses = list(get_new_data_since(CSV_UPDATED_TIME_STAMPS[9] + 1)) - assert len(responses) < 2, f"Received batched response, shouldn't happen: {responses}" - if len(responses) == 1: - response = responses[0] - if len(response.keys) == 250: + keys = [] + if len(responses) > 0: + keys = flatten([list(response.keys) for response in responses]) + if len(keys) == 250: break time.sleep(1) - assert response is not None, "Did not get any response from Storage" - assert len(response.keys) == 250, f"Not all samples were returned. Samples returned: {response.keys}" + assert len(responses) > 0, "Did not get any response from Storage" + assert len(keys) == 250, f"Not all samples were returned. Samples returned: {keys}" - check_data(response.keys, SECOND_ADDED_CSVS) + check_data(keys, SECOND_ADDED_CSVS) responses = list(get_data_in_interval(0, CSV_UPDATED_TIME_STAMPS[9])) - assert len(responses) == 1, f"Received batched/no response, shouldn't happen: {responses}" - response = responses[0] + assert len(responses) > 0, f"Received no response, shouldn't happen: {responses}" + keys = flatten([list(response.keys) for response in responses]) - check_data(response.keys, FIRST_ADDED_CSVS) + check_data(keys, FIRST_ADDED_CSVS) check_get_current_timestamp() # Check if the storage service is still available. 
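Both the image and CSV tests now record file timestamps at whole-second granularity (`math.floor` of the mtime) instead of milliseconds, which is why a two-second sleep is inserted before adding the second batch: files written within the same second would be indistinguishable to `get_new_data_since(ts + 1)`. A small sketch of that timestamp handling, assuming an illustrative dataset path and file names not taken from the tests:

```python
import math
import os
import pathlib

DATASET_PATH = pathlib.Path("/tmp/test_dataset")  # illustrative path, not the one used by the tests
DATASET_PATH.mkdir(parents=True, exist_ok=True)


def write_and_stamp(name: str, content: bytes) -> int:
    """Write a file and return its modification time truncated to whole seconds."""
    path = DATASET_PATH / name
    path.write_bytes(content)
    return int(math.floor(os.path.getmtime(path)))


# first_ts = write_and_stamp("file_0.csv", b"...")
# time.sleep(2)   # ensures the next batch lands in a strictly later second
# second_ts = write_and_stamp("file_1.csv", b"...")
# get_new_data_since(first_ts + 1) then only returns the second batch
```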
@@ -166,8 +174,8 @@ def main() -> None: try: test_storage() finally: - cleanup_dataset_dir() cleanup_storage_database() + cleanup_dataset_dir() if __name__ == "__main__": diff --git a/modyn/CMakeLists.txt b/modyn/CMakeLists.txt index cd5bb9211..2fde364c1 100644 --- a/modyn/CMakeLists.txt +++ b/modyn/CMakeLists.txt @@ -3,9 +3,10 @@ add_subdirectory("common") ##### MODYN STORAGE BINARY ##### -# TODO(MaxiBoether): add while merging storage PR -#add_executable(modyn-storage storage/src/main.cpp) -#target_link_libraries(modyn-storage PRIVATE modyn) +if (${MODYN_BUILD_STORAGE}) + message(STATUS "Storage is included in this build.") + add_subdirectory(storage) +endif () ##### PLAYGROUND ##### if (${MODYN_BUILD_PLAYGROUND}) diff --git a/modyn/common/cpp/include/modyn/utils/utils.hpp b/modyn/common/cpp/include/modyn/utils/utils.hpp index 55350deff..3e2220b7d 100644 --- a/modyn/common/cpp/include/modyn/utils/utils.hpp +++ b/modyn/common/cpp/include/modyn/utils/utils.hpp @@ -19,6 +19,14 @@ } \ static_assert(true, "End call of macro with a semicolon") +#ifdef NDEBUG +#define DEBUG_ASSERT(expr, msg) \ + do { \ + } while (0) +#else +#define DEBUG_ASSERT(expr, msg) ASSERT((expr), (msg)) +#endif + namespace modyn::utils { bool is_power_of_two(uint64_t value); diff --git a/modyn/config/examples/modyn_config.yaml b/modyn/config/examples/modyn_config.yaml index 140992ba8..a662200c6 100644 --- a/modyn/config/examples/modyn_config.yaml +++ b/modyn/config/examples/modyn_config.yaml @@ -9,7 +9,9 @@ storage: sample_batch_size: 2000000 sample_dbinsertion_batchsize: 1000000 insertion_threads: 8 + retrieval_threads: 8 sample_table_unlogged: true + file_watcher_watchdog_sleep_time_s: 5 datasets: [ { @@ -129,6 +131,22 @@ storage: ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, + }, + { + name: "cloc", + description: "CLOC Dataset", + version: "0.0.1", + base_path: "/datasets/cloc", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "SingleSampleFileWrapper", + file_wrapper_config: + { + file_extension: ".jpg", + label_file_extension: ".label" + }, + ignore_last_timestamp: false, + file_watcher_interval: 999999999, + selector_batch_size: 100000, } ] database: diff --git a/modyn/config/schema/modyn_config_schema.yaml b/modyn/config/schema/modyn_config_schema.yaml index 34fb6d383..d094fc04b 100644 --- a/modyn/config/schema/modyn_config_schema.yaml +++ b/modyn/config/schema/modyn_config_schema.yaml @@ -2,8 +2,7 @@ --- $schema: "http://json-schema.org/draft-04/schema" id: "http://stsci.edu/schemas/yaml-schema/draft-01" -title: - Modyn Configuration +title: Modyn Configuration description: | This is the configuration file for the Modyn. It contains the configuration for the system, adapt as required. @@ -38,7 +37,7 @@ properties: type: number description: | The size of a batch when requesting new samples from storage. All new samples are returned, however, to reduce - the size of a single answer the keys are batched in sizes of `sample_batch_size`. + the size of a single answer the keys are batched in sizes of `sample_batch_size`. Defaults to 10000. sample_dbinsertion_batchsize: type: number description: | @@ -47,6 +46,10 @@ properties: type: number description: | The number of threads used to insert samples into the storage DB. If set to <= 0, multithreaded inserts are disabled. + retrieval_threads: + type: number + description: | + The number of threads used to get samples from the storage DB. If set to <= 1, multithreaded gets are disabled. 
sample_table_unlogged: type: boolean description: | @@ -56,7 +59,11 @@ properties: force_fallback_insert: type: boolean description: | - When enabled, always use SQLAlchemy insert functionality instead of potentially optimized techniques. + When enabled, always use fallback insert functionality instead of potentially optimized techniques. + file_watcher_watchdog_sleep_time_s: + type: number + description: | + The time in seconds the file watcher watchdog sleeps between checking if the file watchers are still alive. Defaults to 3. datasets: type: array items: @@ -376,4 +383,4 @@ required: - model_storage - metadata_database - selector - - trainer_server \ No newline at end of file + - trainer_server diff --git a/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2.py b/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2.py index 9c4f83822..8be1acacd 100644 --- a/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2.py +++ b/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2.py @@ -12,27 +12,25 @@ _sym_db = _symbol_database.Default() - - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18metadata_processor.proto\x12\x12metadata_processor\"F\n\x17RegisterPipelineRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x16\n\x0eprocessor_type\x18\x02 \x01(\t\"\x12\n\x10PipelineResponse\"\xc4\x01\n\x17TrainingMetadataRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12@\n\x10trigger_metadata\x18\x03 \x01(\x0b\x32&.metadata_processor.PerTriggerMetadata\x12>\n\x0fsample_metadata\x18\x04 \x03(\x0b\x32%.metadata_processor.PerSampleMetadata\"\"\n\x12PerTriggerMetadata\x12\x0c\n\x04loss\x18\x01 \x01(\x02\"4\n\x11PerSampleMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0c\n\x04loss\x18\x02 \x01(\x02\"\x1a\n\x18TrainingMetadataResponse2\xf7\x01\n\x11MetadataProcessor\x12h\n\x11register_pipeline\x12+.metadata_processor.RegisterPipelineRequest\x1a$.metadata_processor.PipelineResponse\"\x00\x12x\n\x19process_training_metadata\x12+.metadata_processor.TrainingMetadataRequest\x1a,.metadata_processor.TrainingMetadataResponse\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'metadata_processor_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _REGISTERPIPELINEREQUEST._serialized_start=48 - _REGISTERPIPELINEREQUEST._serialized_end=118 - _PIPELINERESPONSE._serialized_start=120 - _PIPELINERESPONSE._serialized_end=138 - _TRAININGMETADATAREQUEST._serialized_start=141 - _TRAININGMETADATAREQUEST._serialized_end=337 - _PERTRIGGERMETADATA._serialized_start=339 - _PERTRIGGERMETADATA._serialized_end=373 - _PERSAMPLEMETADATA._serialized_start=375 - _PERSAMPLEMETADATA._serialized_end=427 - _TRAININGMETADATARESPONSE._serialized_start=429 - _TRAININGMETADATARESPONSE._serialized_end=455 - _METADATAPROCESSOR._serialized_start=458 - _METADATAPROCESSOR._serialized_end=705 + DESCRIPTOR._options = None + _REGISTERPIPELINEREQUEST._serialized_start = 48 + _REGISTERPIPELINEREQUEST._serialized_end = 118 + _PIPELINERESPONSE._serialized_start = 120 + _PIPELINERESPONSE._serialized_end = 138 + _TRAININGMETADATAREQUEST._serialized_start = 141 + _TRAININGMETADATAREQUEST._serialized_end = 337 + _PERTRIGGERMETADATA._serialized_start = 339 + _PERTRIGGERMETADATA._serialized_end = 373 + _PERSAMPLEMETADATA._serialized_start = 375 + _PERSAMPLEMETADATA._serialized_end = 427 + 
_TRAININGMETADATARESPONSE._serialized_start = 429 + _TRAININGMETADATARESPONSE._serialized_end = 455 + _METADATAPROCESSOR._serialized_start = 458 + _METADATAPROCESSOR._serialized_end = 705 # @@protoc_insertion_point(module_scope) diff --git a/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2_grpc.py b/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2_grpc.py index 247853bf8..9969b3f03 100644 --- a/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2_grpc.py +++ b/modyn/metadata_processor/internal/grpc/generated/metadata_processor_pb2_grpc.py @@ -14,15 +14,15 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.register_pipeline = channel.unary_unary( - '/metadata_processor.MetadataProcessor/register_pipeline', - request_serializer=metadata__processor__pb2.RegisterPipelineRequest.SerializeToString, - response_deserializer=metadata__processor__pb2.PipelineResponse.FromString, - ) + '/metadata_processor.MetadataProcessor/register_pipeline', + request_serializer=metadata__processor__pb2.RegisterPipelineRequest.SerializeToString, + response_deserializer=metadata__processor__pb2.PipelineResponse.FromString, + ) self.process_training_metadata = channel.unary_unary( - '/metadata_processor.MetadataProcessor/process_training_metadata', - request_serializer=metadata__processor__pb2.TrainingMetadataRequest.SerializeToString, - response_deserializer=metadata__processor__pb2.TrainingMetadataResponse.FromString, - ) + '/metadata_processor.MetadataProcessor/process_training_metadata', + request_serializer=metadata__processor__pb2.TrainingMetadataRequest.SerializeToString, + response_deserializer=metadata__processor__pb2.TrainingMetadataResponse.FromString, + ) class MetadataProcessorServicer(object): @@ -43,56 +43,57 @@ def process_training_metadata(self, request, context): def add_MetadataProcessorServicer_to_server(servicer, server): rpc_method_handlers = { - 'register_pipeline': grpc.unary_unary_rpc_method_handler( - servicer.register_pipeline, - request_deserializer=metadata__processor__pb2.RegisterPipelineRequest.FromString, - response_serializer=metadata__processor__pb2.PipelineResponse.SerializeToString, - ), - 'process_training_metadata': grpc.unary_unary_rpc_method_handler( - servicer.process_training_metadata, - request_deserializer=metadata__processor__pb2.TrainingMetadataRequest.FromString, - response_serializer=metadata__processor__pb2.TrainingMetadataResponse.SerializeToString, - ), + 'register_pipeline': grpc.unary_unary_rpc_method_handler( + servicer.register_pipeline, + request_deserializer=metadata__processor__pb2.RegisterPipelineRequest.FromString, + response_serializer=metadata__processor__pb2.PipelineResponse.SerializeToString, + ), + 'process_training_metadata': grpc.unary_unary_rpc_method_handler( + servicer.process_training_metadata, + request_deserializer=metadata__processor__pb2.TrainingMetadataRequest.FromString, + response_serializer=metadata__processor__pb2.TrainingMetadataResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'metadata_processor.MetadataProcessor', rpc_method_handlers) + 'metadata_processor.MetadataProcessor', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. 
+ + class MetadataProcessor(object): """Missing associated documentation comment in .proto file.""" @staticmethod def register_pipeline(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/metadata_processor.MetadataProcessor/register_pipeline', - metadata__processor__pb2.RegisterPipelineRequest.SerializeToString, - metadata__processor__pb2.PipelineResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + metadata__processor__pb2.RegisterPipelineRequest.SerializeToString, + metadata__processor__pb2.PipelineResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def process_training_metadata(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/metadata_processor.MetadataProcessor/process_training_metadata', - metadata__processor__pb2.TrainingMetadataRequest.SerializeToString, - metadata__processor__pb2.TrainingMetadataResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + metadata__processor__pb2.TrainingMetadataRequest.SerializeToString, + metadata__processor__pb2.TrainingMetadataResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/model_storage/internal/grpc/generated/model_storage_pb2_grpc.py b/modyn/model_storage/internal/grpc/generated/model_storage_pb2_grpc.py index 454f847bb..fb5d3afc2 100644 --- a/modyn/model_storage/internal/grpc/generated/model_storage_pb2_grpc.py +++ b/modyn/model_storage/internal/grpc/generated/model_storage_pb2_grpc.py @@ -14,20 +14,20 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.RegisterModel = channel.unary_unary( - '/modyn.model_storage.ModelStorage/RegisterModel', - request_serializer=model__storage__pb2.RegisterModelRequest.SerializeToString, - response_deserializer=model__storage__pb2.RegisterModelResponse.FromString, - ) + '/modyn.model_storage.ModelStorage/RegisterModel', + request_serializer=model__storage__pb2.RegisterModelRequest.SerializeToString, + response_deserializer=model__storage__pb2.RegisterModelResponse.FromString, + ) self.FetchModel = channel.unary_unary( - '/modyn.model_storage.ModelStorage/FetchModel', - request_serializer=model__storage__pb2.FetchModelRequest.SerializeToString, - response_deserializer=model__storage__pb2.FetchModelResponse.FromString, - ) + '/modyn.model_storage.ModelStorage/FetchModel', + request_serializer=model__storage__pb2.FetchModelRequest.SerializeToString, + response_deserializer=model__storage__pb2.FetchModelResponse.FromString, + ) self.DeleteModel = channel.unary_unary( - '/modyn.model_storage.ModelStorage/DeleteModel', - request_serializer=model__storage__pb2.DeleteModelRequest.SerializeToString, - response_deserializer=model__storage__pb2.DeleteModelResponse.FromString, - ) + '/modyn.model_storage.ModelStorage/DeleteModel', + request_serializer=model__storage__pb2.DeleteModelRequest.SerializeToString, + response_deserializer=model__storage__pb2.DeleteModelResponse.FromString, + ) class ModelStorageServicer(object): @@ -54,78 +54,79 @@ def DeleteModel(self, request, context): def add_ModelStorageServicer_to_server(servicer, server): rpc_method_handlers = { - 'RegisterModel': grpc.unary_unary_rpc_method_handler( - servicer.RegisterModel, - request_deserializer=model__storage__pb2.RegisterModelRequest.FromString, - response_serializer=model__storage__pb2.RegisterModelResponse.SerializeToString, - ), - 'FetchModel': grpc.unary_unary_rpc_method_handler( - servicer.FetchModel, - request_deserializer=model__storage__pb2.FetchModelRequest.FromString, - response_serializer=model__storage__pb2.FetchModelResponse.SerializeToString, - ), - 'DeleteModel': grpc.unary_unary_rpc_method_handler( - servicer.DeleteModel, - request_deserializer=model__storage__pb2.DeleteModelRequest.FromString, - response_serializer=model__storage__pb2.DeleteModelResponse.SerializeToString, - ), + 'RegisterModel': grpc.unary_unary_rpc_method_handler( + servicer.RegisterModel, + request_deserializer=model__storage__pb2.RegisterModelRequest.FromString, + response_serializer=model__storage__pb2.RegisterModelResponse.SerializeToString, + ), + 'FetchModel': grpc.unary_unary_rpc_method_handler( + servicer.FetchModel, + request_deserializer=model__storage__pb2.FetchModelRequest.FromString, + response_serializer=model__storage__pb2.FetchModelResponse.SerializeToString, + ), + 'DeleteModel': grpc.unary_unary_rpc_method_handler( + servicer.DeleteModel, + request_deserializer=model__storage__pb2.DeleteModelRequest.FromString, + response_serializer=model__storage__pb2.DeleteModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'modyn.model_storage.ModelStorage', rpc_method_handlers) + 'modyn.model_storage.ModelStorage', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. 
+ + class ModelStorage(object): """Missing associated documentation comment in .proto file.""" @staticmethod def RegisterModel(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/modyn.model_storage.ModelStorage/RegisterModel', - model__storage__pb2.RegisterModelRequest.SerializeToString, - model__storage__pb2.RegisterModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + model__storage__pb2.RegisterModelRequest.SerializeToString, + model__storage__pb2.RegisterModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def FetchModel(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/modyn.model_storage.ModelStorage/FetchModel', - model__storage__pb2.FetchModelRequest.SerializeToString, - model__storage__pb2.FetchModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + model__storage__pb2.FetchModelRequest.SerializeToString, + model__storage__pb2.FetchModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def DeleteModel(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/modyn.model_storage.ModelStorage/DeleteModel', - model__storage__pb2.DeleteModelRequest.SerializeToString, - model__storage__pb2.DeleteModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + model__storage__pb2.DeleteModelRequest.SerializeToString, + model__storage__pb2.DeleteModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/modyn/playground/playground.cpp b/modyn/playground/playground.cpp index 0a2251e1b..b93e351f7 100644 --- a/modyn/playground/playground.cpp +++ b/modyn/playground/playground.cpp @@ -1,3 +1,3 @@ #include -int main() { std::cout << "Hi, I'm Modyn! This is the playground." << std::endl; } +int main() { std::cout << "Hi, I'm Modyn! This is the playground." << '\n'; } diff --git a/modyn/protos/README.md b/modyn/protos/README.md index 16eb38247..2cd8f4492 100644 --- a/modyn/protos/README.md +++ b/modyn/protos/README.md @@ -11,7 +11,7 @@ This assumes python 3.6+ is installed. 
First move to the directory where you want to generate the python files. Then run the following command: -` python -m grpc_tools.protoc -I../../../../protos --python_out=. --grpc_python_out=. --mypy_out=. ../../../../protos/[component_name].proto` +`python -m grpc_tools.protoc -I../../../../protos --python_out=. --grpc_python_out=. --mypy_out=. ../../../../protos/[component_name].proto` This will generate the following files: - [component_name]_pb2.py diff --git a/modyn/protos/storage.proto b/modyn/protos/storage.proto index edb579f13..d0e4cbc09 100644 --- a/modyn/protos/storage.proto +++ b/modyn/protos/storage.proto @@ -1,7 +1,5 @@ syntax = "proto3"; -import "google/protobuf/empty.proto"; - package modyn.storage; service Storage { @@ -18,7 +16,7 @@ service Storage { returns (DatasetAvailableResponse) {} rpc RegisterNewDataset(RegisterNewDatasetRequest) returns (RegisterNewDatasetResponse) {} - rpc GetCurrentTimestamp(google.protobuf.Empty) + rpc GetCurrentTimestamp(GetCurrentTimestampRequest) returns (GetCurrentTimestampResponse) {} rpc DeleteDataset(DatasetAvailableRequest) returns (DeleteDatasetResponse) {} rpc DeleteData(DeleteDataRequest) returns (DeleteDataResponse) {} @@ -35,6 +33,9 @@ message GetResponse { repeated int64 labels = 3; } +// https://github.com/grpc/grpc/issues/15937 +message GetCurrentTimestampRequest {} + message GetNewDataSinceRequest { string dataset_id = 1; int64 timestamp = 2; diff --git a/modyn/selector/internal/grpc/generated/selector_pb2_grpc.py b/modyn/selector/internal/grpc/generated/selector_pb2_grpc.py index 8f9d0f626..9efed2b5d 100644 --- a/modyn/selector/internal/grpc/generated/selector_pb2_grpc.py +++ b/modyn/selector/internal/grpc/generated/selector_pb2_grpc.py @@ -14,25 +14,25 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.register_pipeline = channel.unary_unary( - '/selector.Selector/register_pipeline', - request_serializer=selector__pb2.RegisterPipelineRequest.SerializeToString, - response_deserializer=selector__pb2.PipelineResponse.FromString, - ) + '/selector.Selector/register_pipeline', + request_serializer=selector__pb2.RegisterPipelineRequest.SerializeToString, + response_deserializer=selector__pb2.PipelineResponse.FromString, + ) self.get_sample_keys_and_weights = channel.unary_stream( - '/selector.Selector/get_sample_keys_and_weights', - request_serializer=selector__pb2.GetSamplesRequest.SerializeToString, - response_deserializer=selector__pb2.SamplesResponse.FromString, - ) + '/selector.Selector/get_sample_keys_and_weights', + request_serializer=selector__pb2.GetSamplesRequest.SerializeToString, + response_deserializer=selector__pb2.SamplesResponse.FromString, + ) self.inform_data = channel.unary_unary( '/selector.Selector/inform_data', request_serializer=selector__pb2.DataInformRequest.SerializeToString, response_deserializer=selector__pb2.DataInformResponse.FromString, ) self.inform_data_and_trigger = channel.unary_unary( - '/selector.Selector/inform_data_and_trigger', - request_serializer=selector__pb2.DataInformRequest.SerializeToString, - response_deserializer=selector__pb2.TriggerResponse.FromString, - ) + '/selector.Selector/inform_data_and_trigger', + request_serializer=selector__pb2.DataInformRequest.SerializeToString, + response_deserializer=selector__pb2.TriggerResponse.FromString, + ) self.get_number_of_samples = channel.unary_unary( '/selector.Selector/get_number_of_samples', request_serializer=selector__pb2.GetNumberOfSamplesRequest.SerializeToString, @@ -199,59 +199,60 @@ def add_SelectorServicer_to_server(servicer, server): ), } generic_handler = grpc.method_handlers_generic_handler( - 'selector.Selector', rpc_method_handlers) + 'selector.Selector', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. 
+ + class Selector(object): """Missing associated documentation comment in .proto file.""" @staticmethod def register_pipeline(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/register_pipeline', - selector__pb2.RegisterPipelineRequest.SerializeToString, - selector__pb2.PipelineResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + selector__pb2.RegisterPipelineRequest.SerializeToString, + selector__pb2.PipelineResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def get_sample_keys_and_weights(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/selector.Selector/get_sample_keys_and_weights', - selector__pb2.GetSamplesRequest.SerializeToString, - selector__pb2.SamplesResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + selector__pb2.GetSamplesRequest.SerializeToString, + selector__pb2.SamplesResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def inform_data(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/inform_data', selector__pb2.DataInformRequest.SerializeToString, selector__pb2.DataInformResponse.FromString, @@ -260,37 +261,37 @@ def inform_data(request, @staticmethod def inform_data_and_trigger(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/inform_data_and_trigger', - selector__pb2.DataInformRequest.SerializeToString, - selector__pb2.TriggerResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + selector__pb2.DataInformRequest.SerializeToString, + selector__pb2.TriggerResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def get_number_of_samples(request, - target, - options=(), - channel_credentials=None, - 
call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/get_number_of_samples', - selector__pb2.GetNumberOfSamplesRequest.SerializeToString, - selector__pb2.NumberOfSamplesResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + selector__pb2.GetNumberOfSamplesRequest.SerializeToString, + selector__pb2.NumberOfSamplesResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def get_status_bar_scale(request, @@ -311,20 +312,20 @@ def get_status_bar_scale(request, @staticmethod def get_number_of_partitions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/get_number_of_partitions', - selector__pb2.GetNumberOfPartitionsRequest.SerializeToString, - selector__pb2.NumberOfPartitionsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + selector__pb2.GetNumberOfPartitionsRequest.SerializeToString, + selector__pb2.NumberOfPartitionsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def get_available_labels(request, @@ -345,15 +346,15 @@ def get_available_labels(request, @staticmethod def get_selection_strategy(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/selector.Selector/get_selection_strategy', selector__pb2.GetSelectionStrategyRequest.SerializeToString, selector__pb2.SelectionStrategyResponse.FromString, diff --git a/modyn/storage/CMakeLists.txt b/modyn/storage/CMakeLists.txt new file mode 100644 index 000000000..8b69e8171 --- /dev/null +++ b/modyn/storage/CMakeLists.txt @@ -0,0 +1,11 @@ +### Make modyn-storage-library lib available as target in next steps ### +add_library(modyn-storage-library) + +set(MODYN_STORAGE_CMAKE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cmake) + +##### modyn-storage-library ##### +add_subdirectory(src/) + +### Main binary ### +add_executable(modyn-storage src/main.cpp) +target_link_libraries(modyn-storage PRIVATE modyn modyn-storage-library argparse spdlog) \ No newline at end of file diff --git a/modyn/storage/README.md b/modyn/storage/README.md index e8b967228..fb195436d 100644 --- a/modyn/storage/README.md +++ b/modyn/storage/README.md @@ -2,87 +2,113 @@ This is the storage submodule. -Storage is the abstraction layer for the data storage. 
It is responsible for retrieving samples from the actual storage systems and providing them to the GPU nodes for training upon request. The storage component is started using `modyn-storage config.yaml`. The script should be in PATH after installing the `modyn` module. The configuration file describes the system setup. +Storage is the abstraction layer for the data storage. +It is responsible for retrieving samples from the actual storage systems and providing them to the GPU nodes for training upon request. +The storage component is started using `modyn-storage config.yaml`. +The binary should be in PATH after building the `modyn` module. +The configuration file describes the system setup. --- -## How the storage abstraction works: +## How the storage abstraction works -The storage abstraction works with the concept of datasets. Each dataset is identified by a unique name and describes a set of files that are stored in a storage system (for more information see the subsection on [How the storage database works](#how-the-storage-database-works)). Each file may contain one or more samples. A dataset is defined by a filesystem wrapper and a file wrapper. The filesystem wrapper describes how to access the underlying filesystem, while the file wrapper describes how to access the samples within the file. The storage abstraction is designed to be flexible and allow for different storage systems and file formats. +The storage abstraction works with the concept of datasets. +Each dataset is identified by a unique name and describes a set of files that are stored in a storage system (for more information see the subsection on [How the storage database works](#how-the-storage-database-works)). +Each file may contain one or more samples. +A dataset is defined by a filesystem wrapper and a file wrapper. +The filesystem wrapper describes how to access the underlying filesystem, while the file wrapper describes how to access the samples within the file. +The storage abstraction is designed to be flexible and allow for different storage systems and file formats. -### Filesystem wrappers: +### Filesystem wrappers The following filesystem wrappers are currently implemented: -- `local`: Accesses the local filesystem +- `LocalFilesystemWrapper`: Accesses the local filesystem. Future filesystem wrappers may include: -- `s3`: Accesses the Amazon S3 storage system -- `gcs`: Accesses the Google Cloud Storage system +- `s3`: Accesses the Amazon S3 storage system. +- `gcs`: Accesses the Google Cloud Storage system. -See the `modyn/storage/internal/filesystem_wrappers` directory for more information. +See the `modyn/storage/include/internal/filesystem_wrapper` directory for more information. **How to add a new filesystem wrapper:** -To add a new filesystem wrapper, you need to implement the `AbstractFilesystemWrapper` class. The class is defined in `modyn/storage/internal/filesystem_wrapper/abstractfilesystem_wrapper.py`. +To add a new filesystem wrapper, you need to implement the `FilesystemWrapper` abstract class. +The class is defined in `modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper.hpp`. -### File wrappers: +### File wrappers The following file wrappers are currently implemented: -- `single_sample`: Each file contains a single sample +- `SingleSampleFileWrapper`: Each file contains a single sample. +- `BinaryFileWrapper`: Each file contains columns and rows in a binary format. +- `CsvFileWrapper`: Each file contains columns and rows in CSV format.
Future file wrappers may include: -- `tfrecord`: Each file contains multiple samples in the [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format -- `hdf5`: Each file contains multiple samples in the [HDF5](https://www.hdfgroup.org/solutions/hdf5/) format -- `parquet`: Each file contains multiple samples in the [Parquet](https://parquet.apache.org/) format +- `tfrecord`: Each file contains multiple samples in the [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format. +- `hdf5`: Each file contains multiple samples in the [HDF5](https://www.hdfgroup.org/solutions/hdf5/) format. +- `parquet`: Each file contains multiple samples in the [Parquet](https://parquet.apache.org/) format. -See the `modyn/storage/internal/file_wrappers` directory for more information. +See the `modyn/storage/include/internal/file_wrapper` directory for more information. **How to add a new file wrapper:** -To add a new file wrapper, you need to implement the `AbstractFileWrapper` class. The class is defined in `modyn/storage/internal/file_wrapper/abstractfile_wrapper.py`. +To add a new file wrapper, you need to implement the `FileWrapper` class. +The class is defined in `modyn/storage/include/internal/file_wrapper/file_wrapper.hpp`. --- -## How to add a dataset: +## How to add a dataset There are two ways to add a dataset to the storage abstraction: -- Define the dataset in the configuration file and start the storage component using `modyn-storage path/to/config.yaml`. If the dataset is not yet in the database, it will be added automatically. If the dataset is already in the database, the database entry will be updated. -- Register the dataset using the grpc interface. The grpc interface is defined in `modyn/protos/storage.proto`. The call is `RegisterNewDataset`. +- Define the dataset in the configuration file and start the storage component using `modyn-storage path/to/config.yaml`. + If the dataset is not yet in the database, it will be added automatically. + If the dataset is already in the database, the database entry will be updated. +- Register the dataset using the gRPC interface. + The gRPC interface is defined in `modyn/protos/storage.proto`. + The call is `RegisterNewDataset`. --- -## How to add a file to a dataset (NewFileWatcher): +## How to add a file to a dataset (FileWatcherWatchdog) -A file is added to the storage abstraction automatically when the file is created in the underlying storage system. The storage abstraction will periodically check the underlying storage system for new files. If a new file is found, it will be added to the database. The component that is responsible for checking the underlying storage system is called the `NewFileWatcher`. The `NewFileWatcher` is started automatically when the storage component is started. The `NewFileWatcher` is defined in `modyn/storage/internal/new_file_watcher.py`. The `NewFileWatcher` periodically checks for each dataset if there are new files in the underlying storage system. If a new file is found, it and the samples in the file are added to the database. - -Files and samples are expected to be added by a separate component or an altogether different system. The `Storage` component is only responsible for checking for new files and adding them to the database as well as providing the samples to the GPU nodes. It is thus a read-only component. +A file is added to the storage abstraction automatically when the file is created in the underlying storage system.
+The storage abstraction will periodically check the underlying storage system for new files. +If a new file is found, it will be added to the database. +The component that is responsible for checking the underlying storage system is called the `FileWatcherWatchdog`. +The `FileWatcherWatchdog` is started automatically when the storage component is started. +The `FileWatcherWatchdog` is defined in `modyn/storage/include/internal/file_watcher/file_watcher_watchdog.hpp`. +The `FileWatcherWatchdog` periodically checks for each dataset whether there are new files in the underlying storage system, using an instance of a `FileWatcher` as defined in `modyn/storage/include/internal/file_watcher/file_watcher.hpp`. +If a new file is found, it and the samples in the file are added to the database. +Files and samples are expected to be added by a separate component or an altogether different system. +The `Storage` component is only responsible for checking for new files and adding them to the database as well as providing the samples to the GPU nodes. +It is thus a read-only component. --- -## How the storage database works: - -The storage abstraction uses a database to store information about the datasets. The database contains the following tables: - -- `datasets`: Contains information about the datasets - - `dataset_id`: The id of the dataset (primary key) - - `name`: The name of the dataset - - `description`: A description of the dataset - - `filesystem_wrapper_type`: The name of the filesystem wrapper - - `file_wrapper_type`: The name of the file wrapper - - `base_path`: The base path of the dataset -- `files`: Contains information about the files in the datasets - - `file_id`: The id of the file (primary key) - - `dataset_id`: The id of the dataset (foreign key to `datasets.dataset_id`) - - `path`: The path of the file - - `created_at`: The timestamp when the file was created - - `updated_at`: The timestamp when the file was updated - - `number_of_samples`: The number of samples in the file -- `samples`: Contains information about the samples in the files - - `sample_id`: The id of the sample (primary key) - - `file_id`: The id of the file (foreign key to `files.file_id`) - - `index`: The index of the sample in the file \ No newline at end of file +## How the storage database works + +The storage abstraction uses a database to store information about the datasets. +The database contains the following tables: + +- `datasets`: Contains information about the datasets. + - `dataset_id`: The id of the dataset (primary key). + - `name`: The name of the dataset. + - `description`: A description of the dataset. + - `filesystem_wrapper_type`: The name of the filesystem wrapper. + - `file_wrapper_type`: The name of the file wrapper. + - `base_path`: The base path of the dataset. +- `files`: Contains information about the files in the datasets. + - `file_id`: The id of the file (primary key). + - `dataset_id`: The id of the dataset (foreign key to `datasets.dataset_id`). + - `path`: The path of the file. + - `created_at`: The timestamp when the file was created. + - `updated_at`: The timestamp when the file was updated. + - `number_of_samples`: The number of samples in the file. +- `samples`: Contains information about the samples in the files. + - `sample_id`: The id of the sample (primary key). + - `file_id`: The id of the file (foreign key to `files.file_id`). + - `index`: The index of the sample in the file.
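Both registration paths described in the README (defining the dataset in the config file, or calling `RegisterNewDataset` over gRPC) ultimately write a row into the `datasets` table. The following is a minimal, illustrative sketch of that underlying step using the `StorageDatabaseConnection` API introduced later in this diff; the dataset name, base path, description, and the `file_wrapper_config` string are made-up examples, and the loaded YAML config is assumed to contain the `storage.database` section that the constructor reads.

```cpp
// Sketch only: register a dataset directly against the metadata database.
#include <yaml-cpp/yaml.h>

#include "internal/database/storage_database_connection.hpp"

using modyn::storage::FilesystemWrapperType;
using modyn::storage::FileWrapperType;
using modyn::storage::StorageDatabaseConnection;

int main() {
  // Assumed to contain storage.database.{drivername, username, password, host, port, database}.
  const YAML::Node config = YAML::LoadFile("config.yaml");

  const StorageDatabaseConnection connection{config};
  connection.create_tables();  // Creates the datasets/files/samples tables described above.

  // Same effect as the RegisterNewDataset gRPC call: inserts (or updates) the dataset row.
  const bool success = connection.add_dataset(
      /*name=*/"mnist_example", /*base_path=*/"/datasets/mnist",
      FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE,
      /*description=*/"Example dataset", /*version=*/"0.0.1",
      /*file_wrapper_config=*/"file_extension: .png",
      /*ignore_last_timestamp=*/false, /*file_watcher_interval=*/5);

  return success ? 0 : 1;
}
```

In normal operation this call is made by the storage binary itself (when the dataset is listed in the config file) or triggered by a client via `RegisterNewDataset`; the sketch only shows the database-level effect.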
diff --git a/modyn/storage/__init__.py b/modyn/storage/__init__.py index 4f7969a3a..982984594 100644 --- a/modyn/storage/__init__.py +++ b/modyn/storage/__init__.py @@ -1,12 +1,11 @@ -"""Storage module. +""" +Storage module. -The storage module contains all classes and functions related to the storage and retrieval of data. +The storage module contains all classes and functions related to the storage and retrieval of data. """ import os -from .storage import Storage # noqa: F401 - files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") __all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/include/internal/database/cursor_handler.hpp b/modyn/storage/include/internal/database/cursor_handler.hpp new file mode 100644 index 000000000..50b7770e8 --- /dev/null +++ b/modyn/storage/include/internal/database/cursor_handler.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include + +#include + +#include "internal/database/storage_database_connection.hpp" + +namespace modyn::storage { + +struct SampleRecord { + int64_t id; + int64_t column_1; + int64_t column_2; +}; + +/* +Implements a server-side cursor on Postgres and emulates it for sqlite. +For a given query, results are returned in buffered chunks (via the yield_per function) to avoid filling up memory. +*/ +class CursorHandler { + public: + CursorHandler(soci::session& session, DatabaseDriver driver, const std::string& query, std::string cursor_name, + int16_t number_of_columns = 3) + : driver_{driver}, + session_{session}, + query_{query}, + cursor_name_{std::move(cursor_name)}, + number_of_columns_{number_of_columns} { + // ncol = 0 or = 1 means that we only return the first column in the result of the query (typically, the ID) + // ncol = 2 returns the second column as well (typically what you want if you want an id + some property) + // ncol = 3 returns the third as well + // This could be generalized but currently is hardcoded. + // A SampleRecord is populated and (as can be seen above) only has three properties per row.
+ ASSERT(number_of_columns <= 3 && number_of_columns >= 0, "We currently only support 0 - 3 columns."); + + switch (driver_) { + case DatabaseDriver::POSTGRESQL: { + auto* postgresql_session_backend = static_cast(session_.get_backend()); + PGconn* conn = postgresql_session_backend->conn_; + + const std::string declare_cursor = fmt::format("DECLARE {} CURSOR WITH HOLD FOR {}", cursor_name_, query); + PGresult* result = PQexec(conn, declare_cursor.c_str()); + + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + SPDLOG_ERROR("Cursor declaration failed: {}", PQerrorMessage(conn)); + PQclear(result); + break; + } + + PQclear(result); + + postgresql_conn_ = conn; + break; + } + case DatabaseDriver::SQLITE3: { + rs_ = std::make_unique>(session_.prepare << query); + break; + } + default: + FAIL("Unsupported database driver"); + } + + open_ = true; + } + ~CursorHandler() { close_cursor(); } + CursorHandler(const CursorHandler&) = delete; + CursorHandler& operator=(const CursorHandler&) = delete; + CursorHandler(CursorHandler&&) = delete; + CursorHandler& operator=(CursorHandler&&) = delete; + std::vector yield_per(uint64_t number_of_rows_to_fetch); + void close_cursor(); + + private: + void check_cursor_initialized(); + DatabaseDriver driver_; + soci::session& session_; + std::string query_; + std::string cursor_name_; + int16_t number_of_columns_; + std::unique_ptr> rs_{nullptr}; + PGconn* postgresql_conn_{nullptr}; + bool open_{false}; +}; +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/include/internal/database/storage_database_connection.hpp b/modyn/storage/include/internal/database/storage_database_connection.hpp new file mode 100644 index 000000000..0439677d8 --- /dev/null +++ b/modyn/storage/include/internal/database/storage_database_connection.hpp @@ -0,0 +1,145 @@ +#pragma once + +#include + +#include + +#include "internal/file_wrapper/file_wrapper.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "modyn/utils/utils.hpp" +#include "soci/postgresql/soci-postgresql.h" +#include "soci/soci.h" +#include "soci/sqlite3/soci-sqlite3.h" +#include "yaml-cpp/yaml.h" + +namespace modyn::storage { + +enum class DatabaseDriver { POSTGRESQL, SQLITE3 }; + +class StorageDatabaseConnection { + public: + explicit StorageDatabaseConnection(const YAML::Node& config) { + if (!config["storage"]["database"]) { + FAIL("No database configuration found"); + } + drivername_ = get_drivername(config); + username_ = config["storage"]["database"]["username"].as(); + password_ = config["storage"]["database"]["password"].as(); + host_ = config["storage"]["database"]["host"].as(); + port_ = config["storage"]["database"]["port"].as(); + database_ = config["storage"]["database"]["database"].as(); + if (config["storage"]["database"]["hash_partition_modulus"]) { + hash_partition_modulus_ = config["storage"]["database"]["hash_partition_modulus"].as(); + } + if (config["storage"]["sample_table_unlogged"]) { + sample_table_unlogged_ = config["storage"]["sample_table_unlogged"].as(); + } + } + void create_tables() const; + bool add_dataset(const std::string& name, const std::string& base_path, + const FilesystemWrapperType& filesystem_wrapper_type, const FileWrapperType& file_wrapper_type, + const std::string& description, const std::string& version, const std::string& file_wrapper_config, + bool ignore_last_timestamp, int64_t file_watcher_interval = 5) const; + bool delete_dataset(const std::string& name, int64_t dataset_id) const; + bool 
add_sample_dataset_partition(const std::string& dataset_name) const; + soci::session get_session() const; + DatabaseDriver get_drivername() const { return drivername_; } + template + static T get_from_row(soci::row& row, uint64_t pos) { + // This function is needed to make dispatching soci's typing system easier... + const soci::column_properties& props = row.get_properties(pos); + if constexpr (std::is_same_v) { + switch (props.get_data_type()) { + case soci::dt_long_long: + static_assert(sizeof(long long) <= sizeof(int64_t), // NOLINT(google-runtime-int) + "We currently assume long long is equal to or less than 64 bit."); + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_integer: + // NOLINTNEXTLINE(google-runtime-int) + static_assert(sizeof(int) <= sizeof(int64_t), "We currently assume int is equal to or less than 64 bit."); + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_unsigned_long_long: + FAIL(fmt::format("Tried to extract integer from unsigned long long column {}", props.get_name())); + break; + case soci::dt_string: + FAIL(fmt::format("Tried to extract integer from string column {}", props.get_name())); + break; + case soci::dt_double: + FAIL(fmt::format("Tried to extract integer from double column {}", props.get_name())); + break; + case soci::dt_date: + FAIL(fmt::format("Tried to extract integer from data column {}", props.get_name())); + break; + default: + FAIL(fmt::format("Tried to extract integer from unknown data type ({}) column {}", + static_cast(props.get_data_type()), props.get_name())); + } + } + + if constexpr (std::is_same_v) { + switch (props.get_data_type()) { + case soci::dt_unsigned_long_long: + static_assert(sizeof(unsigned long long) <= sizeof(uint64_t), // NOLINT(google-runtime-int) + "We currently assume unsined long long is equal to or less than 64 bit."); + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_long_long: + FAIL(fmt::format("Tried to extract unsigned long long from signed long long column {}", props.get_name())); + case soci::dt_integer: + FAIL(fmt::format("Tried to extract unsigned long long from signed integer column {}", props.get_name())); + case soci::dt_string: + FAIL(fmt::format("Tried to extract integer from string column {}", props.get_name())); + break; + case soci::dt_double: + FAIL(fmt::format("Tried to extract integer from double column {}", props.get_name())); + break; + case soci::dt_date: + FAIL(fmt::format("Tried to extract integer from data column {}", props.get_name())); + break; + default: + FAIL(fmt::format("Tried to extract integer from unknown data type ({}) column {}", + static_cast(props.get_data_type()), props.get_name())); + } + } + + if constexpr (std::is_same_v) { + switch (props.get_data_type()) { + case soci::dt_unsigned_long_long: + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_long_long: + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_integer: + return static_cast(row.get(pos)); // NOLINT(google-runtime-int) + case soci::dt_string: + FAIL(fmt::format("Tried to extract bool from string column {}", props.get_name())); + break; + case soci::dt_double: + FAIL(fmt::format("Tried to extract bool from double column {}", props.get_name())); + break; + case soci::dt_date: + FAIL(fmt::format("Tried to extract bool from data column {}", props.get_name())); + break; + default: + FAIL(fmt::format("Tried to extract bool from unknown data type ({}) column {}", + 
static_cast(props.get_data_type()), props.get_name())); + } + } + + const std::type_info& ti1 = typeid(T); + const std::string type_id = ti1.name(); + FAIL(fmt::format("Unsupported type in get_from_row: {}", type_id)); + } + + private: + static DatabaseDriver get_drivername(const YAML::Node& config); + int64_t get_dataset_id(const std::string& name) const; + std::string username_; + std::string password_; + std::string host_; + std::string port_; + std::string database_; + bool sample_table_unlogged_ = false; + int16_t hash_partition_modulus_ = 8; + DatabaseDriver drivername_; +}; + +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_watcher/file_watcher.hpp b/modyn/storage/include/internal/file_watcher/file_watcher.hpp new file mode 100644 index 000000000..5b1beca53 --- /dev/null +++ b/modyn/storage/include/internal/file_watcher/file_watcher.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "internal/database/storage_database_connection.hpp" +#include "internal/file_wrapper/file_wrapper.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +struct FileFrame { + // Struct to store file information for insertion into the database when watching a dataset. + int64_t file_id; + int64_t index; + int64_t label; +}; +class FileWatcher { + public: + explicit FileWatcher(const YAML::Node& config, int64_t dataset_id, std::atomic* stop_file_watcher, + int16_t insertion_threads = 1) + : stop_file_watcher{stop_file_watcher}, + config_{config}, + dataset_id_{dataset_id}, + insertion_threads_{insertion_threads}, + disable_multithreading_{insertion_threads <= 1}, + storage_database_connection_{StorageDatabaseConnection(config)} { + ASSERT(stop_file_watcher != nullptr, "stop_file_watcher_ is nullptr."); + SPDLOG_INFO("Initializing file watcher for dataset {}.", dataset_id_); + + if (config_["storage"]["sample_dbinsertion_batchsize"]) { + sample_dbinsertion_batchsize_ = config_["storage"]["sample_dbinsertion_batchsize"].as(); + } + if (config_["storage"]["force_fallback"]) { + force_fallback_ = config["storage"]["force_fallback"].as(); + } + soci::session session = storage_database_connection_.get_session(); + + std::string dataset_path; + auto filesystem_wrapper_type_int = static_cast(FilesystemWrapperType::INVALID_FSW); + std::string file_wrapper_config; + auto file_wrapper_type_id = static_cast(FileWrapperType::INVALID_FW); + try { + session << "SELECT base_path, filesystem_wrapper_type, file_wrapper_type, file_wrapper_config FROM datasets " + "WHERE dataset_id = :dataset_id", + soci::into(dataset_path), soci::into(filesystem_wrapper_type_int), soci::into(file_wrapper_type_id), + soci::into(file_wrapper_config), soci::use(dataset_id_); + } catch (const soci::soci_error& e) { + SPDLOG_ERROR("Error while reading dataset path and filesystem wrapper type from database: {}", e.what()); + *stop_file_watcher = true; + return; + } + + session.close(); + + filesystem_wrapper_type_ = static_cast(filesystem_wrapper_type_int); + + SPDLOG_INFO("FileWatcher for dataset {} uses path {}, file_wrapper_id {} and file_system_id {}", dataset_id_, + dataset_path, file_wrapper_type_id, filesystem_wrapper_type_int); + + if (dataset_path.empty()) { + SPDLOG_ERROR("Dataset with id {} not found.", dataset_id_); + *stop_file_watcher = true; + return; + } + + filesystem_wrapper 
= get_filesystem_wrapper(filesystem_wrapper_type_); + + dataset_path_ = dataset_path; + + if (!filesystem_wrapper->exists(dataset_path_) || !filesystem_wrapper->is_directory(dataset_path_)) { + SPDLOG_ERROR("Dataset path {} does not exist or is not a directory.", dataset_path_); + *stop_file_watcher = true; + return; + } + + if (file_wrapper_type_id == -1) { + SPDLOG_ERROR("Failed to get file wrapper type"); + *stop_file_watcher = true; + return; + } + + file_wrapper_type_ = static_cast(file_wrapper_type_id); + + if (file_wrapper_config.empty()) { + SPDLOG_ERROR("Failed to get file wrapper config"); + *stop_file_watcher = true; + return; + } + + file_wrapper_config_node_ = YAML::Load(file_wrapper_config); + + if (!file_wrapper_config_node_["file_extension"]) { + SPDLOG_ERROR("Config does not contain file_extension"); + *stop_file_watcher = true; + return; + } + + data_file_extension_ = file_wrapper_config_node_["file_extension"].as(); + + if (!disable_multithreading_) { + insertion_thread_pool_.reserve(insertion_threads_); + insertion_thread_exceptions_ = std::vector>(insertion_threads_); + } + SPDLOG_INFO("FileWatcher for dataset {} initialized", dataset_id_); + } + void run(); + void search_for_new_files_in_directory(const std::string& directory_path, int64_t timestamp); + void seek_dataset(soci::session& session); + void seek(soci::session& session); + static void handle_file_paths(std::vector::iterator file_paths_begin, + std::vector::iterator file_paths_end, FileWrapperType file_wrapper_type, + int64_t timestamp, FilesystemWrapperType filesystem_wrapper_type, int64_t dataset_id, + const YAML::Node* file_wrapper_config, const YAML::Node* config, + int64_t sample_dbinsertion_batchsize, bool force_fallback, + std::atomic* exception_thrown); + static void handle_files_for_insertion(std::vector& files_for_insertion, + const FileWrapperType& file_wrapper_type, int64_t dataset_id, + const YAML::Node& file_wrapper_config, int64_t sample_dbinsertion_batchsize, + bool force_fallback, soci::session& session, DatabaseDriver& database_driver, + const std::shared_ptr& filesystem_wrapper); + static void insert_file_samples(const std::vector& file_samples, int64_t dataset_id, bool force_fallback, + soci::session& session, DatabaseDriver& database_driver); + static int64_t insert_file(const std::string& file_path, int64_t dataset_id, + const std::shared_ptr& filesystem_wrapper, + const std::unique_ptr& file_wrapper, soci::session& session, + DatabaseDriver& database_driver); + static bool check_file_for_insertion(const std::string& file_path, bool ignore_last_timestamp, int64_t timestamp, + int64_t dataset_id, const std::shared_ptr& filesystem_wrapper, + soci::session& session); + static void postgres_copy_insertion(const std::vector& file_samples, int64_t dataset_id, + soci::session& session); + static void fallback_insertion(const std::vector& file_samples, int64_t dataset_id, + soci::session& session); + static int64_t insert_file(const std::string& file_path, int64_t dataset_id, soci::session& session, + uint64_t number_of_samples, int64_t modified_time); + static int64_t insert_file_using_returning_statement(const std::string& file_path, int64_t dataset_id, + soci::session& session, uint64_t number_of_samples, + int64_t modified_time); + std::atomic* stop_file_watcher; + std::shared_ptr filesystem_wrapper; + + private: + YAML::Node config_; + int64_t dataset_id_ = -1; + int16_t insertion_threads_ = 1; + bool disable_multithreading_ = false; + std::vector insertion_thread_pool_ = {}; + 
std::vector> insertion_thread_exceptions_ = {}; + int64_t sample_dbinsertion_batchsize_ = 1000000; + bool force_fallback_ = false; + StorageDatabaseConnection storage_database_connection_; + std::string dataset_path_; + FilesystemWrapperType filesystem_wrapper_type_; + FileWrapperType file_wrapper_type_; + YAML::Node file_wrapper_config_node_; + std::string data_file_extension_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_watcher/file_watcher_watchdog.hpp b/modyn/storage/include/internal/file_watcher/file_watcher_watchdog.hpp new file mode 100644 index 000000000..c2af7fbfb --- /dev/null +++ b/modyn/storage/include/internal/file_watcher/file_watcher_watchdog.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "file_watcher.hpp" +#include "internal/database/storage_database_connection.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +class FileWatcherWatchdog { + public: + FileWatcherWatchdog(const YAML::Node& config, std::atomic* stop_file_watcher_watchdog, + std::atomic* request_storage_shutdown) + : config_{config}, + stop_file_watcher_watchdog_{stop_file_watcher_watchdog}, + request_storage_shutdown_{request_storage_shutdown}, + storage_database_connection_{StorageDatabaseConnection(config_)} { + ASSERT(stop_file_watcher_watchdog_ != nullptr, "stop_file_watcher_watchdog_ is nullptr."); + ASSERT(config_["storage"]["insertion_threads"], "Config does not contain insertion_threads"); + + if (config_["storage"]["file_watcher_watchdog_sleep_time_s"]) { + file_watcher_watchdog_sleep_time_s_ = config_["storage"]["file_watcher_watchdog_sleep_time_s"].as(); + } + } + void watch_file_watcher_threads(); + void start_file_watcher_thread(int64_t dataset_id); + void stop_file_watcher_thread(int64_t dataset_id); + void run(); + void stop() { + SPDLOG_INFO("FileWatcherWatchdog requesting storage shutdown!"); + stop_file_watcher_watchdog_->store(true); + request_storage_shutdown_->store(true); + } + std::vector get_running_file_watcher_threads(); + + private: + void stop_and_clear_all_file_watcher_threads(); + YAML::Node config_; + int64_t file_watcher_watchdog_sleep_time_s_ = 3; + std::map file_watchers_ = {}; + std::map file_watcher_threads_ = {}; + std::map file_watcher_dataset_retries_ = {}; + std::map> file_watcher_thread_stop_flags_ = {}; + // Used to stop the FileWatcherWatchdog thread from storage main thread + std::atomic* stop_file_watcher_watchdog_; + std::atomic* request_storage_shutdown_; + StorageDatabaseConnection storage_database_connection_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp new file mode 100644 index 000000000..e2d9b191d --- /dev/null +++ b/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include +#include +#include + +#include "internal/file_wrapper/file_wrapper.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { +class BinaryFileWrapper : public FileWrapper { + public: + BinaryFileWrapper(const std::string& path, const YAML::Node& fw_config, + std::shared_ptr filesystem_wrapper) + : FileWrapper(path, fw_config, std::move(filesystem_wrapper)) { + ASSERT(filesystem_wrapper_ != nullptr, "Filesystem wrapper cannot be null."); + ASSERT(fw_config["record_size"], "record_size must be specified in the file wrapper config."); + 
ASSERT(fw_config["label_size"], "label_size be specified in the file wrapper config."); + + record_size_ = fw_config["record_size"].as(); + label_size_ = fw_config["label_size"].as(); + sample_size_ = record_size_ - label_size_; + validate_file_extension(); + file_size_ = filesystem_wrapper_->get_file_size(path); + + ASSERT(static_cast(record_size_ - label_size_) >= 1, + "Each record must have at least 1 byte of data other than the label."); + ASSERT(file_size_ % record_size_ == 0, "File size must be a multiple of the record size."); + + stream_ = filesystem_wrapper_->get_stream(path); + } + uint64_t get_number_of_samples() override; + int64_t get_label(uint64_t index) override; + std::vector get_all_labels() override; + std::vector get_sample(uint64_t index) override; + std::vector> get_samples(uint64_t start, uint64_t end) override; + std::vector> get_samples_from_indices(const std::vector& indices) override; + void validate_file_extension() override; + void delete_samples(const std::vector& indices) override; + void set_file_path(const std::string& path) override; + FileWrapperType get_type() override; + ~BinaryFileWrapper() override { + if (stream_->is_open()) { + stream_->close(); + } + } + BinaryFileWrapper(const BinaryFileWrapper&) = default; + BinaryFileWrapper& operator=(const BinaryFileWrapper&) = default; + BinaryFileWrapper(BinaryFileWrapper&&) = default; + BinaryFileWrapper& operator=(BinaryFileWrapper&&) = default; + + private: + static void validate_request_indices(uint64_t total_samples, const std::vector& indices); + static int64_t int_from_bytes(const unsigned char* begin, const unsigned char* end); + std::ifstream* get_stream(); + uint64_t record_size_; + uint64_t label_size_; + uint64_t file_size_; + uint64_t sample_size_; + std::shared_ptr stream_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp new file mode 100644 index 000000000..a46372b51 --- /dev/null +++ b/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include +#include + +#include "internal/file_wrapper/file_wrapper.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +class CsvFileWrapper : public FileWrapper { + public: + CsvFileWrapper(const std::string& path, const YAML::Node& fw_config, + std::shared_ptr filesystem_wrapper) + : FileWrapper{path, fw_config, std::move(filesystem_wrapper)} { + ASSERT(file_wrapper_config_["label_index"], "Please specify the index of the column that contains the label."); + label_index_ = file_wrapper_config_["label_index"].as(); + + if (file_wrapper_config_["separator"]) { + separator_ = file_wrapper_config_["separator"].as(); + } else { + separator_ = ','; + } + + bool ignore_first_line = false; + if (file_wrapper_config_["ignore_first_line"]) { + ignore_first_line = file_wrapper_config_["ignore_first_line"].as(); + } else { + ignore_first_line = false; + } + + ASSERT(filesystem_wrapper_->exists(path), "The file does not exist."); + + validate_file_extension(); + + label_params_ = rapidcsv::LabelParams(ignore_first_line ? 
0 : -1); + + stream_ = filesystem_wrapper_->get_stream(path); + + doc_ = rapidcsv::Document(*stream_, label_params_, rapidcsv::SeparatorParams(separator_)); + } + + ~CsvFileWrapper() override { + if (stream_->is_open()) { + stream_->close(); + } + } + CsvFileWrapper(const CsvFileWrapper&) = default; + CsvFileWrapper& operator=(const CsvFileWrapper&) = default; + CsvFileWrapper(CsvFileWrapper&&) = default; + CsvFileWrapper& operator=(CsvFileWrapper&&) = default; + + uint64_t get_number_of_samples() override; + int64_t get_label(uint64_t index) override; + std::vector get_all_labels() override; + std::vector get_sample(uint64_t index) override; + std::vector> get_samples(uint64_t start, uint64_t end) override; + std::vector> get_samples_from_indices(const std::vector& indices) override; + void validate_file_extension() override; + void delete_samples(const std::vector& indices) override; + void set_file_path(const std::string& path) override; + FileWrapperType get_type() override; + + private: + char separator_; + uint64_t label_index_; + rapidcsv::Document doc_; + rapidcsv::LabelParams label_params_; + std::shared_ptr stream_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp new file mode 100644 index 000000000..94df67fbc --- /dev/null +++ b/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include + +#include + +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" + +namespace modyn::storage { + +enum FileWrapperType { INVALID_FW, SINGLE_SAMPLE, BINARY, CSV }; + +class FileWrapper { + public: + FileWrapper(std::string path, const YAML::Node& fw_config, std::shared_ptr filesystem_wrapper) + : file_path_{std::move(path)}, + file_wrapper_config_{fw_config}, + filesystem_wrapper_{std::move(filesystem_wrapper)} {} + virtual uint64_t get_number_of_samples() = 0; + virtual int64_t get_label(uint64_t index) = 0; + virtual std::vector get_all_labels() = 0; + virtual std::vector get_sample(uint64_t index) = 0; + virtual std::vector> get_samples(uint64_t start, uint64_t end) = 0; + virtual std::vector> get_samples_from_indices(const std::vector& indices) = 0; + virtual void validate_file_extension() = 0; + virtual void delete_samples(const std::vector& indices) = 0; + virtual void set_file_path(const std::string& path) = 0; + virtual FileWrapperType get_type() = 0; + static FileWrapperType get_file_wrapper_type(const std::string& type) { + static const std::unordered_map FILE_WRAPPER_TYPE_MAP = { + {"SingleSampleFileWrapper", FileWrapperType::SINGLE_SAMPLE}, + {"BinaryFileWrapper", FileWrapperType::BINARY}, + {"CsvFileWrapper", FileWrapperType::CSV}}; + return FILE_WRAPPER_TYPE_MAP.at(type); + } + virtual ~FileWrapper() = default; + FileWrapper(const FileWrapper&) = default; + FileWrapper& operator=(const FileWrapper&) = default; + FileWrapper(FileWrapper&&) = default; + FileWrapper& operator=(FileWrapper&&) = default; + + protected: + std::string file_path_; + YAML::Node file_wrapper_config_; + std::shared_ptr filesystem_wrapper_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/file_wrapper_utils.hpp b/modyn/storage/include/internal/file_wrapper/file_wrapper_utils.hpp new file mode 100644 index 000000000..772e7259f --- /dev/null +++ b/modyn/storage/include/internal/file_wrapper/file_wrapper_utils.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include 
"internal/file_wrapper/binary_file_wrapper.hpp" +#include "internal/file_wrapper/csv_file_wrapper.hpp" +#include "internal/file_wrapper/file_wrapper.hpp" +#include "internal/file_wrapper/single_sample_file_wrapper.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +std::unique_ptr get_file_wrapper(const std::string& path, const FileWrapperType& type, + const YAML::Node& file_wrapper_config, + const std::shared_ptr& filesystem_wrapper); + +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp new file mode 100644 index 000000000..5bf09fcbc --- /dev/null +++ b/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "internal/file_wrapper/file_wrapper.hpp" + +namespace modyn::storage { + +class SingleSampleFileWrapper : public FileWrapper { + public: + SingleSampleFileWrapper(const std::string& path, const YAML::Node& fw_config, + std::shared_ptr filesystem_wrapper) + : FileWrapper(path, fw_config, std::move(filesystem_wrapper)) { + validate_file_extension(); + } + uint64_t get_number_of_samples() override; + int64_t get_label(uint64_t index) override; + std::vector get_all_labels() override; + std::vector get_sample(uint64_t index) override; + std::vector> get_samples(uint64_t start, uint64_t end) override; + std::vector> get_samples_from_indices(const std::vector& indices) override; + void validate_file_extension() override; + void delete_samples(const std::vector& indices) override; + void set_file_path(const std::string& path) override { file_path_ = path; } + FileWrapperType get_type() override; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper.hpp b/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper.hpp new file mode 100644 index 000000000..d9e10cd09 --- /dev/null +++ b/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include +#include +#include + +namespace modyn::storage { + +enum FilesystemWrapperType { INVALID_FSW, LOCAL }; + +class FilesystemWrapper { + public: + FilesystemWrapper() = default; + virtual std::vector get(const std::string& path) = 0; + virtual bool exists(const std::string& path) = 0; + virtual std::vector list(const std::string& path, bool recursive, std::string extension) = 0; + virtual bool is_directory(const std::string& path) = 0; + virtual bool is_file(const std::string& path) = 0; + virtual uint64_t get_file_size(const std::string& path) = 0; + virtual int64_t get_modified_time(const std::string& path) = 0; + virtual bool is_valid_path(const std::string& path) = 0; + virtual std::shared_ptr get_stream(const std::string& path) = 0; + virtual FilesystemWrapperType get_type() = 0; + virtual bool remove(const std::string& path) = 0; + static FilesystemWrapperType get_filesystem_wrapper_type(const std::string& type) { + static const std::unordered_map FILESYSTEM_WRAPPER_TYPE_MAP = { + {"LocalFilesystemWrapper", FilesystemWrapperType::LOCAL}, + }; + return FILESYSTEM_WRAPPER_TYPE_MAP.at(type); + } + virtual ~FilesystemWrapper() = default; + FilesystemWrapper(const FilesystemWrapper&) = default; + FilesystemWrapper& operator=(const FilesystemWrapper&) = default; + FilesystemWrapper(FilesystemWrapper&&) = default; + FilesystemWrapper& 
operator=(FilesystemWrapper&&) = default; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper_utils.hpp b/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper_utils.hpp new file mode 100644 index 000000000..92982acee --- /dev/null +++ b/modyn/storage/include/internal/filesystem_wrapper/filesystem_wrapper_utils.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "internal/filesystem_wrapper/local_filesystem_wrapper.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +std::shared_ptr get_filesystem_wrapper(const FilesystemWrapperType& type); + +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/include/internal/filesystem_wrapper/local_filesystem_wrapper.hpp b/modyn/storage/include/internal/filesystem_wrapper/local_filesystem_wrapper.hpp new file mode 100644 index 000000000..bf926469c --- /dev/null +++ b/modyn/storage/include/internal/filesystem_wrapper/local_filesystem_wrapper.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" + +namespace modyn::storage { +class LocalFilesystemWrapper : public FilesystemWrapper { + public: + LocalFilesystemWrapper() = default; + std::vector get(const std::string& path) override; + bool exists(const std::string& path) override; + std::vector list(const std::string& path, bool recursive, std::string extension) override; + bool is_directory(const std::string& path) override; + bool is_file(const std::string& path) override; + uint64_t get_file_size(const std::string& path) override; + int64_t get_modified_time(const std::string& path) override; + bool is_valid_path(const std::string& path) override; + std::shared_ptr get_stream(const std::string& path) override; + FilesystemWrapperType get_type() override; + bool remove(const std::string& path) override; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/internal/grpc/storage_grpc_server.hpp b/modyn/storage/include/internal/grpc/storage_grpc_server.hpp new file mode 100644 index 000000000..4abd62728 --- /dev/null +++ b/modyn/storage/include/internal/grpc/storage_grpc_server.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace modyn::storage { + +class StorageGrpcServer { + public: + StorageGrpcServer(const YAML::Node& config, std::atomic* stop_grpc_server, + std::atomic* request_storage_shutdown) + : config_{config}, stop_grpc_server_{stop_grpc_server}, request_storage_shutdown_{request_storage_shutdown} {} + void run(); + void stop() { + SPDLOG_INFO("gRPC Server requesting storage shutdown"); + stop_grpc_server_->store(true); + request_storage_shutdown_->store(true); + } + + private: + YAML::Node config_; + std::atomic* stop_grpc_server_; + std::atomic* request_storage_shutdown_; +}; + +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/include/internal/grpc/storage_service_impl.hpp b/modyn/storage/include/internal/grpc/storage_service_impl.hpp new file mode 100644 index 000000000..4bc42c269 --- /dev/null +++ b/modyn/storage/include/internal/grpc/storage_service_impl.hpp @@ -0,0 +1,549 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "internal/database/cursor_handler.hpp" +#include "internal/database/storage_database_connection.hpp" +#include "internal/file_wrapper/file_wrapper_utils.hpp" +#include 
"internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" +#include "modyn/utils/utils.hpp" + +// Since grpc > 1.54.2, there are extra semicola and a missing override in +// the external generated header. Since we want to have -Werror and diagnostics +// on our code, we temporarily disable the warnings when importing this generated header. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wextra-semi" +#if defined(__clang__) +// This is only a clang error... +#pragma GCC diagnostic ignored "-Winconsistent-missing-override" +#endif +#include "storage.grpc.pb.h" +#pragma GCC diagnostic pop + +namespace modyn::storage { + +using namespace grpc; + +struct SampleData { + std::vector ids; + std::vector indices; + std::vector labels; +}; + +struct DatasetData { + int64_t dataset_id = -1; + std::string base_path; + FilesystemWrapperType filesystem_wrapper_type = FilesystemWrapperType::INVALID_FSW; + FileWrapperType file_wrapper_type = FileWrapperType::INVALID_FW; + std::string file_wrapper_config; +}; + +class StorageServiceImpl final : public modyn::storage::Storage::Service { + public: + explicit StorageServiceImpl(const YAML::Node& config, uint64_t retrieval_threads = 1) + : Service(), // NOLINT readability-redundant-member-init (we need to call the base constructor) + config_{config}, + retrieval_threads_{retrieval_threads}, + disable_multithreading_{retrieval_threads <= 1}, + storage_database_connection_{config} { + if (!config_["storage"]["sample_batch_size"]) { + SPDLOG_ERROR("No sample_batch_size specified in config.yaml"); + return; + } + sample_batch_size_ = config_["storage"]["sample_batch_size"].as(); + + if (disable_multithreading_) { + SPDLOG_INFO("Multithreading disabled."); + } else { + SPDLOG_INFO("Multithreading enabled."); + } + } + + Status Get(ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) override; + Status GetNewDataSince(ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, + ServerWriter* writer) override; + Status GetDataInInterval(ServerContext* context, const modyn::storage::GetDataInIntervalRequest* request, + ServerWriter* writer) override; + Status CheckAvailability(ServerContext* context, const modyn::storage::DatasetAvailableRequest* request, + modyn::storage::DatasetAvailableResponse* response) override; + Status RegisterNewDataset(ServerContext* context, const modyn::storage::RegisterNewDatasetRequest* request, + modyn::storage::RegisterNewDatasetResponse* response) override; + Status GetCurrentTimestamp(ServerContext* context, const modyn::storage::GetCurrentTimestampRequest* request, + modyn::storage::GetCurrentTimestampResponse* response) override; + Status DeleteDataset(ServerContext* context, const modyn::storage::DatasetAvailableRequest* request, + modyn::storage::DeleteDatasetResponse* response) override; + Status DeleteData(ServerContext* context, const modyn::storage::DeleteDataRequest* request, + modyn::storage::DeleteDataResponse* response) override; + Status GetDataPerWorker(ServerContext* context, const modyn::storage::GetDataPerWorkerRequest* request, + ServerWriter<::modyn::storage::GetDataPerWorkerResponse>* writer) override; + Status GetDatasetSize(ServerContext* context, const modyn::storage::GetDatasetSizeRequest* request, + modyn::storage::GetDatasetSizeResponse* response) override; + + template + Status Get_Impl( // NOLINT (readability-identifier-naming) + ServerContext* /*context*/, const 
modyn::storage::GetRequest* request, WriterT* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + std::string dataset_name = request->dataset_id(); + const DatasetData dataset_data = get_dataset_data(session, dataset_name); + + SPDLOG_INFO(fmt::format("Received GetRequest for dataset {} (id = {}) with {} keys.", dataset_name, + dataset_data.dataset_id, request->keys_size())); + + if (dataset_data.dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + session.close(); + return {StatusCode::OK, "Dataset does not exist."}; + } + + const auto keys_size = static_cast(request->keys_size()); + if (keys_size == 0) { + return {StatusCode::OK, "No keys provided."}; + } + + std::vector request_keys; + request_keys.reserve(keys_size); + std::copy(request->keys().begin(), request->keys().end(), std::back_inserter(request_keys)); + + send_sample_data_from_keys(writer, request_keys, dataset_data); + + // sqlite causes memory leaks otherwise + if (session.get_backend_name() != "sqlite3" && session.is_connected()) { + session.close(); + } + + return {StatusCode::OK, "Data retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in Get: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in Get: {}", e.what())}; + } + } + + template + Status GetNewDataSince_Impl( // NOLINT (readability-identifier-naming) + ServerContext* /*context*/, const modyn::storage::GetNewDataSinceRequest* request, WriterT* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + const int64_t dataset_id = get_dataset_id(session, request->dataset_id()); + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", dataset_id); + session.close(); + return {StatusCode::OK, "Dataset does not exist."}; + } + session.close(); + + const int64_t request_timestamp = request->timestamp(); + + SPDLOG_INFO(fmt::format("Received GetNewDataSince Request for dataset {} (id = {}) with timestamp {}.", + request->dataset_id(), dataset_id, request_timestamp)); + + send_file_ids_and_labels(writer, dataset_id, request_timestamp); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in GetNewDataSince: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in GetNewDataSince: {}", e.what())}; + } + return {StatusCode::OK, "Data retrieved."}; + } + + template + Status GetDataInInterval_Impl( // NOLINT (readability-identifier-naming) + ServerContext* /*context*/, const modyn::storage::GetDataInIntervalRequest* request, WriterT* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + const int64_t dataset_id = get_dataset_id(session, request->dataset_id()); + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", dataset_id); + session.close(); + return {StatusCode::OK, "Dataset does not exist."}; + } + session.close(); + + const int64_t start_timestamp = request->start_timestamp(); + const int64_t end_timestamp = request->end_timestamp(); + + SPDLOG_INFO( + fmt::format("Received GetDataInInterval Request for dataset {} (id = {}) with start = {} and end = {}.", + request->dataset_id(), dataset_id, start_timestamp, end_timestamp)); + + send_file_ids_and_labels(writer, dataset_id, start_timestamp, + end_timestamp); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in GetDataInInterval: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in GetDataInInterval: {}", e.what())}; + } + return {StatusCode::OK, 
"Data retrieved."}; + } + + template > + void send_sample_data_from_keys(WriterT* writer, const std::vector& request_keys, + const DatasetData& dataset_data) { + // Create mutex to protect the writer from concurrent writes as this is not supported by gRPC + std::mutex writer_mutex; + + if (disable_multithreading_) { + const std::vector::const_iterator begin = request_keys.begin(); // NOLINT (modernize-use-auto) + const std::vector::const_iterator end = request_keys.end(); // NOLINT (modernize-use-auto) + + get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_); + + } else { + std::vector::const_iterator, std::vector::const_iterator>> + its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_); + std::vector retrieval_threads_vector(retrieval_threads_); + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + const std::vector::const_iterator begin = its_per_thread[thread_id].first; + const std::vector::const_iterator end = its_per_thread[thread_id].second; + + retrieval_threads_vector[thread_id] = + std::thread(StorageServiceImpl::get_samples_and_send, begin, end, writer, &writer_mutex, + &dataset_data, &config_, sample_batch_size_); + } + + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + if (retrieval_threads_vector[thread_id].joinable()) { + retrieval_threads_vector[thread_id].join(); + } + } + retrieval_threads_vector.clear(); + } + } + + template > + void send_file_ids_and_labels(WriterT* writer, const int64_t dataset_id, const int64_t start_timestamp = -1, + int64_t end_timestamp = -1) { + soci::session session = storage_database_connection_.get_session(); + // TODO(#359): We might want to have a cursor for this as well and iterate over it, since that can also + // return millions of files + const std::vector file_ids = get_file_ids(session, dataset_id, start_timestamp, end_timestamp); + session.close(); + + if (file_ids.empty()) { + return; + } + std::mutex writer_mutex; // We need to protect the writer from concurrent writes as this is not supported by gRPC + const bool force_no_mt = true; + // TODO(#360): Fix multithreaded sample retrieval here + SPDLOG_ERROR("Multithreaded retrieval of new samples is currently broken, disabling..."); + + if (force_no_mt || disable_multithreading_) { + send_sample_id_and_label(writer, &writer_mutex, file_ids.begin(), file_ids.end(), &config_, + dataset_id, sample_batch_size_); + } else { + // Split the number of files over retrieval_threads_ + std::vector::const_iterator, std::vector::const_iterator>> + file_ids_per_thread = get_keys_per_thread(file_ids, retrieval_threads_); + + std::vector retrieval_threads_vector(retrieval_threads_); + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + retrieval_threads_vector[thread_id] = + std::thread(StorageServiceImpl::send_sample_id_and_label, writer, &writer_mutex, + file_ids_per_thread[thread_id].first, file_ids_per_thread[thread_id].second, &config_, + dataset_id, sample_batch_size_); + } + + for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { + if (retrieval_threads_vector[thread_id].joinable()) { + retrieval_threads_vector[thread_id].join(); + } + } + } + } + + template > + static void send_sample_id_and_label(WriterT* writer, // NOLINT (readability-function-cognitive-complexity) + std::mutex* writer_mutex, const std::vector::const_iterator begin, + const std::vector::const_iterator end, const YAML::Node* config, + int64_t dataset_id, int64_t 
sample_batch_size) { + if (begin >= end) { + return; + } + + const StorageDatabaseConnection storage_database_connection(*config); + soci::session session = storage_database_connection.get_session(); + + const int64_t num_paths = end - begin; + // TODO(#361): Do not hardcode this number + const auto chunk_size = static_cast(1000000); + int64_t num_chunks = num_paths / chunk_size; + if (num_paths % chunk_size != 0) { + ++num_chunks; + } + + for (int64_t i = 0; i < num_chunks; ++i) { + auto start_it = begin + i * chunk_size; + auto end_it = i < num_chunks - 1 ? start_it + chunk_size : end; + + std::vector file_ids(start_it, end_it); + std::string file_placeholders = fmt::format("({})", fmt::join(file_ids, ",")); + + std::vector record_buf; + record_buf.reserve(sample_batch_size); + + const std::string query = fmt::format( + "SELECT samples.sample_id, samples.label, files.updated_at " + "FROM samples INNER JOIN files " + "ON samples.file_id = files.file_id AND samples.dataset_id = files.dataset_id " + "WHERE samples.file_id IN {} AND samples.dataset_id = {} " + "ORDER BY files.updated_at ASC", + file_placeholders, dataset_id); + const std::string cursor_name = fmt::format("cursor_{}_{}", dataset_id, file_ids.at(0)); + CursorHandler cursor_handler(session, storage_database_connection.get_drivername(), query, cursor_name, 3); + + std::vector records; + + while (true) { + ASSERT(static_cast(record_buf.size()) < sample_batch_size, + fmt::format("Should have written records buffer, size = {}", record_buf.size())); + records = cursor_handler.yield_per(sample_batch_size); + + if (records.empty()) { + break; + } + + const uint64_t obtained_records = records.size(); + ASSERT(static_cast(obtained_records) <= sample_batch_size, "Received too many samples"); + + if (static_cast(obtained_records) == sample_batch_size) { + // If we obtained a full buffer, we can emit a response directly + ResponseT response; + for (const auto& record : records) { + response.add_keys(record.id); + response.add_labels(record.column_1); + response.add_timestamps(record.column_2); + } + + /* SPDLOG_INFO("Sending with response_keys = {}, response_labels = {}, records.size = {}", + response.keys_size(), response.labels_size(), records.size()); */ + + records.clear(); + + { + const std::lock_guard lock(*writer_mutex); + writer->Write(response); + } + } else { + // If not, we append to our record buf + record_buf.insert(record_buf.end(), records.begin(), records.end()); + records.clear(); + // If our record buf is big enough, emit a message + if (static_cast(record_buf.size()) >= sample_batch_size) { + ResponseT response; + + // sample_batch_size is signed int... 
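The logic around this point is a buffer-and-flush scheme: cursor results are appended to record_buf, and whenever at least sample_batch_size records have accumulated, exactly sample_batch_size of them are packed into one response and erased from the buffer. A minimal sketch of that flush step, assuming the SampleRecord fields used above (id, column_1 = label, column_2 = timestamp) and a hypothetical helper name; the real code inlines this:

template <typename ResponseT, typename WriterT>
void flush_one_batch(std::vector<SampleRecord>& record_buf, int64_t sample_batch_size,
                     WriterT* writer, std::mutex* writer_mutex) {
  ResponseT response;
  for (int64_t i = 0; i < sample_batch_size; ++i) {
    response.add_keys(record_buf[i].id);
    response.add_labels(record_buf[i].column_1);
    response.add_timestamps(record_buf[i].column_2);
  }
  // Drop the records we just packed; fewer than sample_batch_size remain afterwards,
  // which is exactly the invariant the surrounding ASSERTs check.
  record_buf.erase(record_buf.begin(), record_buf.begin() + sample_batch_size);
  const std::lock_guard<std::mutex> lock(*writer_mutex);
  writer->Write(response);
}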
+ for (int64_t record_idx = 0; record_idx < sample_batch_size; ++record_idx) { + const SampleRecord& record = record_buf[record_idx]; + response.add_keys(record.id); + response.add_labels(record.column_1); + response.add_timestamps(record.column_2); + } + /*SPDLOG_INFO( + "Sending with response_keys = {}, response_labels = {}, record_buf.size = {} (minus sample_batch_size " + "= " + "{})", + response.keys_size(), response.labels_size(), record_buf.size(), sample_batch_size); */ + + // Now, delete first sample_batch_size elements from vector as we are sending them + record_buf.erase(record_buf.begin(), record_buf.begin() + sample_batch_size); + + // SPDLOG_INFO("New record_buf size = {}", record_buf.size()); + + ASSERT(static_cast(record_buf.size()) < sample_batch_size, + "The record buffer should never have more than 2*sample_batch_size elements!"); + + { + const std::lock_guard lock(*writer_mutex); + writer->Write(response); + } + } + } + } + + cursor_handler.close_cursor(); + + // Iterated over all files, we now need to emit all data from buffer + if (!record_buf.empty()) { + ASSERT(static_cast(record_buf.size()) < sample_batch_size, + fmt::format("We should have written this buffer before! Buffer has {} items.", record_buf.size())); + + ResponseT response; + for (const auto& record : record_buf) { + response.add_keys(record.id); + response.add_labels(record.column_1); + response.add_timestamps(record.column_2); + } + /* SPDLOG_INFO("Sending with response_keys = {}, response_labels = {}, record_buf.size = {}", + response.keys_size(), response.labels_size(), record_buf.size()); */ + record_buf.clear(); + { + const std::lock_guard lock(*writer_mutex); + writer->Write(response); + } + } + } + + // sqlite causes memory leaks otherwise + if (session.get_backend_name() != "sqlite3" && session.is_connected()) { + session.close(); + } + } + + template > + static void send_sample_data_for_keys_and_file( // NOLINT(readability-function-cognitive-complexity) + WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + // Note that we currently ignore the sample batch size here, under the assumption that users do not request more + // keys than this + try { + const uint64_t num_keys = sample_keys.size(); + std::vector sample_labels(num_keys); + std::vector sample_indices(num_keys); + std::vector sample_fileids(num_keys); + const std::string sample_query = fmt::format( + "SELECT label, sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER " + "BY file_id", + fmt::join(sample_keys, ",")); + session << sample_query, soci::into(sample_labels), soci::into(sample_indices), soci::into(sample_fileids), + soci::use(dataset_data.dataset_id); + + int64_t current_file_id = sample_fileids[0]; + int64_t current_file_start_idx = 0; + std::string current_file_path; + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + + if (current_file_path.empty()) { + SPDLOG_ERROR(fmt::format("Could not obtain full path of file id {} in dataset {}", current_file_id, + dataset_data.dataset_id)); + } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); + auto filesystem_wrapper = + get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); + + auto file_wrapper = + 
get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + + for (uint64_t sample_idx = 0; sample_idx < num_keys; ++sample_idx) { + const int64_t& sample_fileid = sample_fileids[sample_idx]; + + if (sample_fileid != current_file_id) { + // 1. Prepare response + const std::vector file_indexes( + sample_indices.begin() + static_cast(current_file_start_idx), + sample_indices.begin() + static_cast(sample_idx)); + std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + + // Protobuf expects the data as std::string... + std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + data.clear(); + data.shrink_to_fit(); + + modyn::storage::GetResponse response; + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), + sample_keys.begin() + static_cast(sample_idx)); + response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + sample_labels.begin() + static_cast(sample_idx)); + + // 2. Send response + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); + } + + // 3. Update state + current_file_id = sample_fileid; + current_file_path = "", + session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", + soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + file_wrapper->set_file_path(current_file_path); + current_file_start_idx = static_cast(sample_idx); + } + } + + // Send leftovers + const std::vector file_indexes(sample_indices.begin() + current_file_start_idx, sample_indices.end()); + const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); + // Protobuf expects the data as std::string... 
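As the comment above notes, the generated protobuf C++ API represents bytes fields as std::string, so each sample payload, which the file wrappers return as std::vector<char>, is copied into a string via the iterator-range constructor before it is assigned to the response. A minimal illustration of that conversion:

// Copy a raw sample buffer into the std::string form protobuf expects.
const std::vector<char> sample = {'a', 'b', 'c'};
const std::string as_string(sample.begin(), sample.end());  // yields "abc"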
+ std::vector stringified_data; + stringified_data.reserve(data.size()); + for (const std::vector& char_vec : data) { + stringified_data.emplace_back(char_vec.begin(), char_vec.end()); + } + + modyn::storage::GetResponse response; + response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); + response.mutable_keys()->Assign(sample_keys.begin() + current_file_start_idx, sample_keys.end()); + response.mutable_labels()->Assign(sample_labels.begin() + current_file_start_idx, sample_labels.end()); + + { + const std::lock_guard lock(writer_mutex); + writer->Write(response); + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file: {}", e.what()); + throw; + } + } + + template + static void get_samples_and_send(const std::vector::const_iterator begin, + const std::vector::const_iterator end, WriterT* writer, + std::mutex* writer_mutex, const DatasetData* dataset_data, const YAML::Node* config, + int64_t sample_batch_size) { + if (begin >= end) { + return; + } + const StorageDatabaseConnection storage_database_connection(*config); + soci::session session = storage_database_connection.get_session(); + const std::vector sample_keys(begin, end); + send_sample_data_for_keys_and_file(writer, *writer_mutex, sample_keys, *dataset_data, session, + sample_batch_size); + session.close(); + } + + static std::tuple get_partition_for_worker(int64_t worker_id, int64_t total_workers, + int64_t total_num_elements); + static int64_t get_number_of_samples_in_file(int64_t file_id, soci::session& session, int64_t dataset_id); + + static std::vector get_file_ids(soci::session& session, int64_t dataset_id, int64_t start_timestamp = -1, + int64_t end_timestamp = -1); + static uint64_t get_file_count(soci::session& session, int64_t dataset_id, int64_t start_timestamp, + int64_t end_timestamp); + static std::vector get_file_ids_given_number_of_files(soci::session& session, int64_t dataset_id, + int64_t start_timestamp, int64_t end_timestamp, + uint64_t number_of_files); + static int64_t get_dataset_id(soci::session& session, const std::string& dataset_name); + static std::vector get_file_ids_for_samples(const std::vector& request_keys, int64_t dataset_id, + soci::session& session); + static std::vector::const_iterator, std::vector::const_iterator>> + get_keys_per_thread(const std::vector& keys, uint64_t threads); + static std::vector get_samples_corresponding_to_file(int64_t file_id, int64_t dataset_id, + const std::vector& request_keys, + soci::session& session); + static DatasetData get_dataset_data(soci::session& session, std::string& dataset_name); + + private: + YAML::Node config_; + int64_t sample_batch_size_ = 10000; + uint64_t retrieval_threads_; + bool disable_multithreading_; + StorageDatabaseConnection storage_database_connection_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/include/storage_server.hpp b/modyn/storage/include/storage_server.hpp new file mode 100644 index 000000000..37107898d --- /dev/null +++ b/modyn/storage/include/storage_server.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "internal/file_watcher/file_watcher_watchdog.hpp" +#include "internal/grpc/storage_grpc_server.hpp" +#include "yaml-cpp/yaml.h" + +namespace modyn::storage { +class StorageServer { + public: + explicit StorageServer(const std::string& config_file) + : config_{YAML::LoadFile(config_file)}, + connection_{config_}, + file_watcher_watchdog_{config_, &stop_file_watcher_watchdog_, &storage_shutdown_requested_}, + 
grpc_server_{config_, &stop_grpc_server_, &storage_shutdown_requested_} {} + void run(); + + private: + YAML::Node config_; + StorageDatabaseConnection connection_; + std::atomic storage_shutdown_requested_ = false; + std::atomic stop_file_watcher_watchdog_ = false; + std::atomic stop_grpc_server_ = false; + FileWatcherWatchdog file_watcher_watchdog_; + StorageGrpcServer grpc_server_; +}; +} // namespace modyn::storage diff --git a/modyn/storage/internal/__init__.py b/modyn/storage/internal/__init__.py deleted file mode 100644 index 4e54d865f..000000000 --- a/modyn/storage/internal/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Storage module. - -The storage module contains all classes and functions related to the storage and retrieval of data. -""" - -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/database/__init__.py b/modyn/storage/internal/database/__init__.py deleted file mode 100644 index baeb8ee96..000000000 --- a/modyn/storage/internal/database/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""This package contains the database classes for the internal storage module. - -The database classes are used to abstract the database operations. -This allows the storage module to be used with different databases. -""" - -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/database/models/__init__.py b/modyn/storage/internal/database/models/__init__.py deleted file mode 100644 index 493d0dfd1..000000000 --- a/modyn/storage/internal/database/models/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""This package contains all the ORM models for the database. - -The models are used to abstract the database operations. -This allows the storage module to be used with different databases. 
-""" -import os - -from .dataset import Dataset # noqa: F401 -from .file import File # noqa: F401 -from .sample import Sample # noqa: F401 - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/database/models/dataset.py b/modyn/storage/internal/database/models/dataset.py deleted file mode 100644 index 81611f2b5..000000000 --- a/modyn/storage/internal/database/models/dataset.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Dataset model.""" - -from modyn.storage.internal.database.storage_base import StorageBase -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType -from sqlalchemy import BigInteger, Boolean, Column, Enum, Integer, String - - -class Dataset(StorageBase): - """Dataset model.""" - - __tablename__ = "datasets" - # See https://docs.sqlalchemy.org/en/13/core/metadata.html?highlight=extend_existing#sqlalchemy.schema.Table.params.extend_existing # noqa: E501 - __table_args__ = {"extend_existing": True} - dataset_id = Column("dataset_id", Integer, primary_key=True) - name = Column(String(80), index=True, unique=True, nullable=False) - description = Column(String(120), unique=False, nullable=True) - version = Column(String(80), unique=False, nullable=True) - filesystem_wrapper_type = Column(Enum(FilesystemWrapperType), nullable=False) - file_wrapper_type = Column(Enum(FileWrapperType), nullable=False) - base_path = Column(String(120), unique=False, nullable=False) - file_wrapper_config = Column(String(240), unique=False, nullable=True) - last_timestamp = Column(BigInteger, unique=False, nullable=False) - ignore_last_timestamp = Column(Boolean, unique=False, nullable=False, default=False) - file_watcher_interval = Column(BigInteger, unique=False, nullable=False, default=5) - - def __repr__(self) -> str: - """Return string representation.""" - return f"" diff --git a/modyn/storage/internal/database/models/file.py b/modyn/storage/internal/database/models/file.py deleted file mode 100644 index 273d79333..000000000 --- a/modyn/storage/internal/database/models/file.py +++ /dev/null @@ -1,27 +0,0 @@ -"""File model.""" - -from modyn.storage.internal.database.storage_base import StorageBase -from sqlalchemy import BigInteger, Column, ForeignKey, Integer, String -from sqlalchemy.dialects import sqlite -from sqlalchemy.orm import relationship - -BIGINT = BigInteger().with_variant(sqlite.INTEGER(), "sqlite") - - -class File(StorageBase): - """File model.""" - - __tablename__ = "files" - # See https://docs.sqlalchemy.org/en/13/core/metadata.html?highlight=extend_existing#sqlalchemy.schema.Table.params.extend_existing # noqa: E501 - __table_args__ = {"extend_existing": True} - file_id = Column("file_id", BIGINT, autoincrement=True, primary_key=True) - dataset_id = Column(Integer, ForeignKey("datasets.dataset_id"), nullable=False, index=True) - dataset = relationship("Dataset") - path = Column(String(120), unique=False, nullable=False) - created_at = Column(BigInteger, nullable=False) - updated_at = Column(BigInteger, nullable=False, index=True) - number_of_samples = Column(Integer, nullable=False) - - def __repr__(self) -> str: - """Return string representation.""" - return f"" diff --git a/modyn/storage/internal/database/models/sample.py b/modyn/storage/internal/database/models/sample.py deleted file mode 100644 index 440ee73e2..000000000 --- 
a/modyn/storage/internal/database/models/sample.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Sample model.""" - -from typing import Any, Optional - -from modyn.database import PartitionByMeta -from modyn.storage.internal.database.storage_base import StorageBase -from sqlalchemy import BigInteger, Column, Integer -from sqlalchemy.dialects import sqlite -from sqlalchemy.engine import Engine -from sqlalchemy.orm.session import Session -from sqlalchemy.schema import PrimaryKeyConstraint - -BIGINT = BigInteger().with_variant(sqlite.INTEGER(), "sqlite") - - -class SampleMixin: - # Note that we have a composite primary key in the general case because partitioning on the dataset - # requires the dataset_id to be part of the PK. - # Logically, sample_id is sufficient for the PK. - sample_id = Column("sample_id", BIGINT, autoincrement=True, primary_key=True) - dataset_id = Column(Integer, nullable=False, primary_key=True) - file_id = Column(Integer, nullable=True) - # This should not be null but we remove the integrity check in favor of insertion performance. - index = Column(BigInteger, nullable=True) - label = Column(BigInteger, nullable=True) - - -class Sample( - SampleMixin, - StorageBase, - metaclass=PartitionByMeta, - partition_by="dataset_id", # type: ignore - partition_type="LIST", # type: ignore -): - """Sample model.""" - - __tablename__ = "samples" - - def __repr__(self) -> str: - """Return string representation.""" - return f"" - - @staticmethod - def ensure_pks_correct(session: Session) -> None: - if session.bind.dialect.name == "sqlite": - # sqllite does not support AUTOINCREMENT on composite PKs - # As it also does not support partitioning, in case of sqlite, we need to update the model - # to only have sample_id as PK, which has no further implications. dataset_id is only part of the PK - # in the first case as that is required by postgres for partitioning. - # Updating the model at runtime requires hacking the sqlalchemy internals - # and what exactly do change took me a while to figure out. - # This is not officially supported by sqlalchemy. - # Basically, we need to change all the things where dataset_id is part of the PK - # Simply writing Sample.dataset_id.primary_key = False or - # Sample.dataset_id = Column(..., primary_key=False) does not work at runtime. 
- # We first need to mark the column as non primary key - # and then update the constraint (on the Table object, used to create SQL operations) - # Last, we have to update the mapper - # (used during query generation, needs to be synchronized to the Table, otherwise we get an error) - if Sample.__table__.c.dataset_id.primary_key: - Sample.__table__.c.dataset_id.primary_key = False - Sample.__table__.primary_key = PrimaryKeyConstraint(Sample.sample_id) - Sample.__mapper__.primary_key = Sample.__mapper__.primary_key[0:1] - - @staticmethod - def add_dataset( - dataset_id: int, session: Session, engine: Engine, hash_partition_modulus: int = 8, unlogged: bool = True - ) -> None: - partition_stmt = f"FOR VALUES IN ({dataset_id})" - partition_suffix = f"_did{dataset_id}" - dataset_partition = Sample._create_partition( - Sample, - partition_suffix, - partition_stmt=partition_stmt, - subpartition_by="sample_id", - subpartition_type="HASH", - session=session, - engine=engine, - unlogged=unlogged, - ) - - if dataset_partition is None: - return # partitoning disabled - - # Create partitions for sample key hash - for i in range(hash_partition_modulus): - partition_suffix = f"_part{i}" - partition_stmt = f"FOR VALUES WITH (modulus {hash_partition_modulus}, remainder {i})" - Sample._create_partition( - dataset_partition, - partition_suffix, - partition_stmt=partition_stmt, - subpartition_by=None, - subpartition_type=None, - session=session, - engine=engine, - unlogged=unlogged, - ) - - @staticmethod - def _create_partition( - instance: Any, # This is the class itself - partition_suffix: str, - partition_stmt: str, - subpartition_by: Optional[str], - subpartition_type: Optional[str], - session: Session, - engine: Engine, - unlogged: bool, - ) -> Optional[PartitionByMeta]: - """Create a partition for the Sample table.""" - if session.bind.dialect.name == "sqlite": - return None - - # Create partition - partition = instance.create_partition( - partition_suffix, - partition_stmt=partition_stmt, - subpartition_by=subpartition_by, - subpartition_type=subpartition_type, - unlogged=unlogged, - ) - - #  Create table - Sample.metadata.create_all(engine, [partition.__table__]) - - return partition diff --git a/modyn/storage/internal/database/storage_base.py b/modyn/storage/internal/database/storage_base.py deleted file mode 100644 index 291c9ddc8..000000000 --- a/modyn/storage/internal/database/storage_base.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Base model.""" - -from sqlalchemy.orm import DeclarativeBase - - -class StorageBase(DeclarativeBase): - pass diff --git a/modyn/storage/internal/database/storage_database_connection.py b/modyn/storage/internal/database/storage_database_connection.py deleted file mode 100644 index 956206a3c..000000000 --- a/modyn/storage/internal/database/storage_database_connection.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Database connection context manager.""" - -from __future__ import annotations - -import logging - -from modyn.database.abstract_database_connection import AbstractDatabaseConnection -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_base import StorageBase -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType -from sqlalchemy import exc - -logger = logging.getLogger(__name__) - - -class StorageDatabaseConnection(AbstractDatabaseConnection): - """Database connection context 
manager.""" - - def __init__(self, modyn_config: dict) -> None: - """Initialize the database connection. - - Args: - modyn_config (dict): Configuration of the modyn module. - """ - super().__init__(modyn_config) - self.drivername: str = self.modyn_config["storage"]["database"]["drivername"] - self.username: str = self.modyn_config["storage"]["database"]["username"] - self.password: str = self.modyn_config["storage"]["database"]["password"] - self.host: str = self.modyn_config["storage"]["database"]["host"] - self.port: int = self.modyn_config["storage"]["database"]["port"] - self.database: str = self.modyn_config["storage"]["database"]["database"] - self.hash_partition_modulus: int = ( - self.modyn_config["storage"]["database"]["hash_partition_modulus"] - if "hash_partition_modulus" in self.modyn_config["storage"]["database"] - else 8 - ) - self.sample_table_unlogged: bool = ( - self.modyn_config["storage"]["database"]["sample_table_unlogged"] - if "sample_table_unlogged" in self.modyn_config["storage"]["database"] - else True - ) - - def __enter__(self) -> StorageDatabaseConnection: - """Create the engine and session. - - Returns: - DatabaseConnection: DatabaseConnection. - """ - super().__enter__() - return self - - def create_tables(self) -> None: - """ - Create all tables. Each table is represented by a class. - - All classes that inherit from Base are mapped to tables - which are created in the database if they do not exist. - - The metadata is a collection of Table objects that inherit from Base and their associated - schema constructs (such as Column objects, ForeignKey objects, and so on). - """ - Sample.ensure_pks_correct(self.session) - StorageBase.metadata.create_all(self.engine) - - def add_dataset( - self, - name: str, - base_path: str, - filesystem_wrapper_type: FilesystemWrapperType, - file_wrapper_type: FileWrapperType, - description: str, - version: str, - file_wrapper_config: str, - ignore_last_timestamp: bool = False, - file_watcher_interval: int = 5, - ) -> bool: - """ - Add dataset to database. - - If dataset with name already exists, it is updated. 
- """ - try: - if self.session.query(Dataset).filter(Dataset.name == name).first() is not None: - logger.info(f"Dataset with name {name} exists.") - self.session.query(Dataset).filter(Dataset.name == name).update( - { - "base_path": base_path, - "filesystem_wrapper_type": filesystem_wrapper_type, - "file_wrapper_type": file_wrapper_type, - "description": description, - "version": version, - "file_wrapper_config": file_wrapper_config, - "ignore_last_timestamp": ignore_last_timestamp, - "file_watcher_interval": file_watcher_interval, - } - ) - else: - logger.info(f"Dataset with name {name} does not exist.") - dataset = Dataset( - name=name, - base_path=base_path, - filesystem_wrapper_type=filesystem_wrapper_type, - file_wrapper_type=file_wrapper_type, - description=description, - version=version, - file_wrapper_config=file_wrapper_config, - last_timestamp=-1, # Set to -1 as this is a new dataset - ignore_last_timestamp=ignore_last_timestamp, - file_watcher_interval=file_watcher_interval, - ) - self.session.add(dataset) - self.session.commit() - except exc.SQLAlchemyError as exception: - logger.error(f"Error adding dataset: {exception}") - self.session.rollback() - return False - return True - - def delete_dataset(self, name: str) -> bool: - """Delete dataset from database.""" - try: - self.session.query(Sample).filter( - Sample.file_id.in_(self.session.query(File.file_id).join(Dataset).filter(Dataset.name == name)) - ).delete(synchronize_session="fetch") - self.session.query(File).filter( - File.dataset_id.in_(self.session.query(Dataset.dataset_id).filter(Dataset.name == name)) - ).delete(synchronize_session="fetch") - self.session.query(Dataset).filter(Dataset.name == name).delete(synchronize_session="fetch") - self.session.commit() - except exc.SQLAlchemyError as exception: - logger.error(f"Error deleting dataset: {exception}") - self.session.rollback() - return False - return True - - def add_sample_dataset(self, dataset_id: int) -> None: - """Add a new dataset to the samples table. - - This method creates a new partitions for the dataset. - - Args: - dataset_id (int): Id of the dataset - """ - Sample.add_dataset( - dataset_id, self.session, self.engine, self.hash_partition_modulus, self.sample_table_unlogged - ) diff --git a/modyn/storage/internal/database/storage_database_utils.py b/modyn/storage/internal/database/storage_database_utils.py deleted file mode 100644 index bffa372a9..000000000 --- a/modyn/storage/internal/database/storage_database_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Storage database utilities.""" - -import json -import logging - -from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType, InvalidFileWrapperTypeException -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import ( - FilesystemWrapperType, - InvalidFilesystemWrapperTypeException, -) -from modyn.utils import dynamic_module_import - -logger = logging.getLogger(__name__) - - -def get_filesystem_wrapper(filesystem_wrapper_type: FilesystemWrapperType, base_path: str) -> AbstractFileSystemWrapper: - """Get the filesystem wrapper. - - Args: - filesystem_wrapper_type (FilesystemWrapperType): filesystem wrapper type - base_path (str): base path of the filesystem wrapper - - Raises: - InvalidFilesystemWrapperTypeException: Invalid filesystem wrapper type. 
- - Returns: - AbstractFileSystemWrapper: filesystem wrapper - """ - if not isinstance(filesystem_wrapper_type, FilesystemWrapperType): - raise InvalidFilesystemWrapperTypeException("Invalid filesystem wrapper type.") - filesystem_wrapper_module = dynamic_module_import( - f"modyn.storage.internal.filesystem_wrapper.{filesystem_wrapper_type.value}" - ) - filesystem_wrapper = getattr(filesystem_wrapper_module, f"{filesystem_wrapper_type.name}") - return filesystem_wrapper(base_path) - - -def get_file_wrapper( - file_wrapper_type: FileWrapperType, - path: str, - file_wrapper_config: str, - filesystem_wrapper: AbstractFileSystemWrapper, -) -> AbstractFileWrapper: - """Get the file wrapper. - - Args: - file_wrapper_type (FileWrapperType): file wrapper type - path (str): path of the file wrapper - file_wrapper_config (str): file wrapper configuration as json string. - - - Raises: - InvalidFileWrapperTypeException: Invalid file wrapper type. - - Returns: - AbstractFileWrapper: file wrapper - """ - if not isinstance(file_wrapper_type, FileWrapperType): - raise InvalidFileWrapperTypeException("Invalid file wrapper type.") - file_wrapper_config = json.loads(file_wrapper_config) - file_wrapper_module = dynamic_module_import(f"modyn.storage.internal.file_wrapper.{file_wrapper_type.value}") - file_wrapper = getattr(file_wrapper_module, f"{file_wrapper_type.name}") - return file_wrapper(path, file_wrapper_config, filesystem_wrapper) diff --git a/modyn/storage/internal/file_watcher/__init__.py b/modyn/storage/internal/file_watcher/__init__.py deleted file mode 100644 index dfda6853e..000000000 --- a/modyn/storage/internal/file_watcher/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Storage module. - -The storage module contains all classes and functions related to the storage's NewFileWatcher -""" - -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/file_watcher/new_file_watcher.py b/modyn/storage/internal/file_watcher/new_file_watcher.py deleted file mode 100644 index 16a2c4541..000000000 --- a/modyn/storage/internal/file_watcher/new_file_watcher.py +++ /dev/null @@ -1,457 +0,0 @@ -"""New file watcher.""" - -import io -import itertools -import json -import logging -import multiprocessing as mp -import os -import pathlib -import platform -import time -from typing import Any, Optional - -import pandas as pd -from modyn.common.benchmark import Stopwatch -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.database.storage_database_utils import get_file_wrapper, get_filesystem_wrapper -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper -from modyn.utils import current_time_millis -from sqlalchemy import exc -from sqlalchemy.orm import exc as orm_exc -from sqlalchemy.orm.session import Session - -logger = logging.getLogger(__name__) - - -class NewFileWatcher: - """New file watcher. - - This class is responsible for watching all the filesystems of the datasets for new files. If a new file is found, it - will be added to the database. - """ - - def __init__( - self, modyn_config: dict, dataset_id: int, should_stop: Any - ): # See https://github.com/python/typeshed/issues/8799 - """Initialize the new file watcher. - - Args: - modyn_config (dict): Configuration of the modyn module. 
- should_stop (Any): Value that indicates if the new file watcher should stop. - """ - self.modyn_config = modyn_config - self.__should_stop = should_stop - self.__dataset_id = dataset_id - - self._insertion_threads = modyn_config["storage"]["insertion_threads"] - self._sample_dbinsertion_batchsize: int = ( - self.modyn_config["storage"]["sample_dbinsertion_batchsize"] - if "sample_dbinsertion_batchsize" in self.modyn_config["storage"] - else 1000000 - ) - - self._dump_measurements: bool = ( - self.modyn_config["storage"]["dump_performance_measurements"] - if "dump_performance_measurements" in self.modyn_config["storage"] - else False - ) - - self._force_fallback_insert: bool = ( - self.modyn_config["storage"]["force_fallback_insert"] - if "force_fallback_insert" in self.modyn_config["storage"] - else False - ) - - self._is_test = "PYTEST_CURRENT_TEST" in os.environ - self._is_mac = platform.system() == "Darwin" - self._disable_mt = self._insertion_threads <= 0 - - # Initialize dataset partition on Sample table - with StorageDatabaseConnection(self.modyn_config) as database: - database.add_sample_dataset(self.__dataset_id) - - def _seek(self, storage_database_connection: StorageDatabaseConnection, dataset: Dataset) -> None: - """Seek the filesystem for all the datasets for new files and add them to the database. - - If last timestamp is not ignored, the last timestamp of the dataset will be used to only - seek for files that have a timestamp that is equal or greater than the last timestamp. - """ - if dataset is None: - logger.warning( - f"Dataset {self.__dataset_id} not found. Shutting down file watcher for dataset {self.__dataset_id}." - ) - self.__should_stop.value = True - return - session = storage_database_connection.session - try: - logger.debug( - f"Seeking for files in dataset {dataset.dataset_id} with a timestamp that \ - is equal or greater than {dataset.last_timestamp}" - ) - self._seek_dataset(session, dataset) - last_timestamp = ( - session.query(File.updated_at) - .filter(File.dataset_id == dataset.dataset_id) - .order_by(File.updated_at.desc()) - .first() - ) - if last_timestamp is not None: - session.query(Dataset).filter(Dataset.dataset_id == dataset.dataset_id).update( - {"last_timestamp": last_timestamp[0]} - ) - session.commit() - except orm_exc.ObjectDeletedError as error: - # If the dataset was deleted, we should stop the file watcher and delete all the - # orphaned files and samples - logger.warning( - f"Dataset {self.__dataset_id} was deleted. Shutting down " - + f"file watcher for dataset {self.__dataset_id}. Error: {error}" - ) - session.rollback() - storage_database_connection.delete_dataset(dataset.name) - self.__should_stop.value = True - - def _seek_dataset(self, session: Session, dataset: Dataset) -> None: - """Seek the filesystem for a dataset for new files and add them to the database. - - If last timestamp is not ignored, the last timestamp of the dataset will be used to - only seek for files that have a timestamp that is equal or greater than the last timestamp. - - Args: - session (Session): Database session. - dataset (Dataset): Dataset to seek. 
- """ - filesystem_wrapper = get_filesystem_wrapper(dataset.filesystem_wrapper_type, dataset.base_path) - - if filesystem_wrapper.exists(dataset.base_path): - if filesystem_wrapper.isdir(dataset.base_path): - self._update_files_in_directory( - filesystem_wrapper, - dataset.file_wrapper_type, - dataset.base_path, - dataset.last_timestamp, - session, - dataset, - ) - else: - logger.critical(f"Path {dataset.base_path} is not a directory.") - else: - logger.warning(f"Path {dataset.base_path} does not exist.") - - def _get_datasets(self, session: Session) -> list[Dataset]: - """Get all datasets.""" - datasets: Optional[list[Dataset]] = session.query(Dataset).all() - - if datasets is None or len(datasets) == 0: - logger.warning("No datasets found.") - return [] - - return datasets - - @staticmethod - def _file_unknown(session: Session, file_path: str) -> bool: - """Check if a file is unknown. - - TODO (#147): This is a very inefficient way to check if a file is unknown. It should be replaced - by a more efficient method. - """ - return session.query(File).filter(File.path == file_path).first() is None - - @staticmethod - def _postgres_copy_insertion( - process_id: int, dataset_id: int, file_dfs: list[pd.DataFrame], time_spent: dict, session: Session - ) -> None: - stopwatch = Stopwatch() - - stopwatch.start("session_setup") - - conn = session.connection().engine.raw_connection() - cursor = conn.cursor() - - table_name = f"samples__did{dataset_id}" - table_columns = "(dataset_id,file_id,index,label)" - cmd = f"COPY {table_name}{table_columns} FROM STDIN WITH (FORMAT CSV, HEADER FALSE)" - - logger.debug(f"[Process {process_id}] Dumping CSV in buffer.") - stopwatch.stop() - - stopwatch.start("csv_creation") - output = io.StringIO() - for file_df in file_dfs: - file_df.to_csv( - output, sep=",", header=False, index=False, columns=["dataset_id", "file_id", "index", "label"] - ) - - output.seek(0) - stopwatch.stop() - - stopwatch.start("db_insertion") - logger.debug(f"[Process {process_id}] Copying to DB.") - cursor.copy_expert(cmd, output) - conn.commit() - stopwatch.stop() - - time_spent.update(stopwatch.measurements) - - @staticmethod - def _fallback_copy_insertion( - process_id: int, dataset_id: int, file_dfs: list[pd.DataFrame], time_spent: dict, session: Session - ) -> None: - del process_id - del dataset_id - stopwatch = Stopwatch() - - stopwatch.start("dict_creation") - for file_df in file_dfs: - file_df["sample_id"] = None - - data = list(itertools.chain.from_iterable([file_df.to_dict("records") for file_df in file_dfs])) - - stopwatch.stop() - - stopwatch.start("db_insertion") - session.bulk_insert_mappings(Sample, data) - session.commit() - stopwatch.stop() - - time_spent.update(stopwatch.measurements) - - # pylint: disable=too-many-locals,too-many-statements - - @staticmethod - def _handle_file_paths( - process_id: int, - sample_dbinsertion_batchsize: int, - dump_measurements: bool, - force_fallback_inserts: bool, - file_paths: list[str], - modyn_config: dict, - data_file_extension: str, - filesystem_wrapper: AbstractFileSystemWrapper, - file_wrapper_type: str, - timestamp: int, - dataset_name: str, - dataset_id: int, - session: Optional[Session], # When using multithreading, we cannot pass the session, hence it is Optional - ) -> None: - """Given a list of paths (in terms of a Modyn FileSystem) to files, - check whether there are any new files and if so, add all samples from these files into the DB.""" - - assert sample_dbinsertion_batchsize > 0, "Invalid sample_dbinsertion_batchsize" - - 
db_connection: Optional[StorageDatabaseConnection] = None - stopwatch = Stopwatch() - - if session is None: # Multithreaded - db_connection = StorageDatabaseConnection(modyn_config) - db_connection.setup_connection() - session = db_connection.session - - insertion_func = NewFileWatcher._fallback_copy_insertion - - if session.bind.dialect.name == "postgresql": - insertion_func = NewFileWatcher._postgres_copy_insertion - - if force_fallback_inserts: # Needs to come last - insertion_func = NewFileWatcher._fallback_copy_insertion - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == dataset_name).first() - - def check_valid_file(file_path: str) -> bool: - path_obj = pathlib.Path(file_path) - if path_obj.suffix != data_file_extension: - return False - if ( - dataset.ignore_last_timestamp or filesystem_wrapper.get_modified(file_path) >= timestamp - ) and NewFileWatcher._file_unknown(session, file_path): - return True - - return False - - valid_files = list(filter(check_valid_file, file_paths)) - - file_dfs = [] - current_len = 0 - - for num_file, file_path in enumerate(valid_files): - stopwatch.start("init", resume=True) - - file_wrapper = get_file_wrapper( - file_wrapper_type, file_path, dataset.file_wrapper_config, filesystem_wrapper - ) - number_of_samples = file_wrapper.get_number_of_samples() - logger.debug( - f"[Process {process_id}] Found new, unknown file: {file_path} with {number_of_samples} samples." - ) - - stopwatch.stop() - stopwatch.start("file_creation", resume=True) - - try: - file: File = File( - dataset=dataset, - path=file_path, - created_at=filesystem_wrapper.get_created(file_path), - updated_at=filesystem_wrapper.get_modified(file_path), - number_of_samples=number_of_samples, - ) - session.add(file) - session.commit() - except exc.SQLAlchemyError as exception: - logger.critical(f"[Process {process_id}] Could not create file {file_path} in database: {exception}") - session.rollback() - continue - - file_id = file.file_id - logger.info( - f"[Process {process_id}] Extracting and inserting samples for file {file_path} (id = {file_id})" - ) - - stopwatch.stop() - - stopwatch.start("label_extraction", resume=True) - labels = file_wrapper.get_all_labels() - stopwatch.stop() - - logger.debug( - f"[Process {process_id}] Labels extracted in" - + f" {round(stopwatch.measurements['label_extraction'] / 1000, 2)}s." - ) - - stopwatch.start("df_creation", resume=True) - - file_df = pd.DataFrame.from_dict({"dataset_id": dataset_id, "file_id": file_id, "label": labels}) - file_df["index"] = range(len(file_df)) - file_dfs.append(file_df) - current_len += len(file_df) - - stopwatch.stop() - insertion_func_measurements: dict[str, int] = {} - - if current_len >= sample_dbinsertion_batchsize or num_file == len(valid_files) - 1: - logger.debug(f"[Process {process_id}] Inserting {current_len} samples.") - stopwatch.start("insertion_func", resume=True) - - insertion_func(process_id, dataset_id, file_dfs, insertion_func_measurements, session) - - stopwatch.stop() - - logger.debug( - f"[Process {process_id}] Inserted {current_len} samples in" - + f" {round((stopwatch.measurements['insertion_func']) / 1000, 2)}s." 
- ) - - stopwatch.start("cleanup", resume=True) - current_len = 0 - file_dfs.clear() - stopwatch.stop() - - if dump_measurements and len(valid_files) > 0: - measurements = {**stopwatch.measurements, **insertion_func_measurements} - with open( - f"/tmp/modyn_{current_time_millis()}_process{process_id}_stats.json", "w", encoding="utf-8" - ) as statsfile: - json.dump(measurements, statsfile) - - if db_connection is not None: - db_connection.terminate_connection() - - def _update_files_in_directory( - self, - filesystem_wrapper: AbstractFileSystemWrapper, - file_wrapper_type: str, - path: str, - timestamp: int, - session: Session, - dataset: Dataset, - ) -> None: - """Recursively get all files in a directory. - - Get all files that have a timestamp that is equal or greater than the given timestamp.""" - if not filesystem_wrapper.isdir(path): - logger.critical(f"Path {path} is not a directory.") - return - - data_file_extension = json.loads(dataset.file_wrapper_config)["file_extension"] - file_paths = filesystem_wrapper.list(path, recursive=True) - stopwatch = Stopwatch() - - assert self.__dataset_id == dataset.dataset_id - - if self._disable_mt or (self._is_test and self._is_mac): - NewFileWatcher._handle_file_paths( - -1, - self._sample_dbinsertion_batchsize, - self._dump_measurements, - self._force_fallback_insert, - file_paths, - self.modyn_config, - data_file_extension, - filesystem_wrapper, - file_wrapper_type, - timestamp, - dataset.name, - self.__dataset_id, - session, - ) - return - - stopwatch.start("processes") - - files_per_proc = int(len(file_paths) / self._insertion_threads) - processes: list[mp.Process] = [] - for i in range(self._insertion_threads): - start_idx = i * files_per_proc - end_idx = start_idx + files_per_proc if i < self._insertion_threads - 1 else len(file_paths) - paths = file_paths[start_idx:end_idx] - - if len(paths) > 0: - proc = mp.Process( - target=NewFileWatcher._handle_file_paths, - args=( - i, - self._sample_dbinsertion_batchsize, - self._dump_measurements, - self._force_fallback_insert, - paths, - self.modyn_config, - data_file_extension, - filesystem_wrapper, - file_wrapper_type, - timestamp, - dataset.name, - self.__dataset_id, - None, - ), - ) - proc.start() - processes.append(proc) - - for proc in processes: - proc.join() - - runtime = round(stopwatch.stop() / 1000, 2) - if runtime > 5: - logger.debug(f"Processes finished running in in {runtime}s.") - - def run(self) -> None: - """Run the dataset watcher.""" - logger.info("Starting dataset watcher.") - with StorageDatabaseConnection(self.modyn_config) as database: - while not self.__should_stop.value: - dataset = database.session.query(Dataset).filter(Dataset.dataset_id == self.__dataset_id).first() - self._seek(database, dataset) - time.sleep(dataset.file_watcher_interval) - - -def run_new_file_watcher(modyn_config: dict, dataset_id: int, should_stop: Any) -> None: - """Run the file watcher for a dataset. - - Args: - dataset_id (int): Dataset id. - should_stop (Value): Value to check if the file watcher should stop. 
- """ - file_watcher = NewFileWatcher(modyn_config, dataset_id, should_stop) - file_watcher.run() diff --git a/modyn/storage/internal/file_watcher/new_file_watcher_watch_dog.py b/modyn/storage/internal/file_watcher/new_file_watcher_watch_dog.py deleted file mode 100644 index 9044122d9..000000000 --- a/modyn/storage/internal/file_watcher/new_file_watcher_watch_dog.py +++ /dev/null @@ -1,108 +0,0 @@ -import logging -import time -from ctypes import c_bool -from multiprocessing import Process, Value -from typing import Any - -from modyn.storage.internal.database.models import Dataset -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.file_watcher.new_file_watcher import run_new_file_watcher - -logger = logging.getLogger(__name__) - - -class NewFileWatcherWatchDog: - def __init__(self, modyn_config: dict, should_stop: Any): # See https://github.com/python/typeshed/issues/8799 - """Initialize the new file watcher watch dog. - - Args: - modyn_config (dict): Configuration of the modyn module. - should_stop (Any): Value that indicates if the new file watcher should stop. - """ - self.modyn_config = modyn_config - self.__should_stop = should_stop - self._file_watcher_processes: dict[int, tuple[Process, Any, int]] = {} - - def _watch_file_watcher_processes(self) -> None: - """Manage the file watchers. - - This method will check if there are file watchers that are not watching a dataset anymore. If that is the case, - the file watcher will be stopped. - """ - with StorageDatabaseConnection(self.modyn_config) as storage_database_connection: - session = storage_database_connection.session - dataset_ids = [dataset.dataset_id for dataset in session.query(Dataset).all()] - dataset_ids_in_file_watcher_processes = list(self._file_watcher_processes.keys()) - for dataset_id in dataset_ids_in_file_watcher_processes: - if dataset_id not in dataset_ids: - logger.debug(f"Stopping file watcher for dataset {dataset_id}") - self._stop_file_watcher_process(dataset_id) - - for dataset_id in dataset_ids: - if dataset_id not in self._file_watcher_processes: - logger.debug(f"Starting file watcher for dataset {dataset_id}") - self._start_file_watcher_process(dataset_id) - if self._file_watcher_processes[dataset_id][2] > 3: - logger.debug(f"Stopping file watcher for dataset {dataset_id} because it was restarted too often.") - self._stop_file_watcher_process(dataset_id) - elif not self._file_watcher_processes[dataset_id][0].is_alive(): - logger.debug(f"File watcher for dataset {dataset_id} is not alive. Restarting it.") - self._start_file_watcher_process(dataset_id) - self._file_watcher_processes[dataset_id] = ( - self._file_watcher_processes[dataset_id][0], - self._file_watcher_processes[dataset_id][1], - self._file_watcher_processes[dataset_id][2] + 1, - ) - - def _start_file_watcher_process(self, dataset_id: int) -> None: - """Start a file watcher. - - Args: - dataset_id (int): ID of the dataset that should be watched. - """ - should_stop = Value(c_bool, False) - file_watcher = Process(target=run_new_file_watcher, args=(self.modyn_config, dataset_id, should_stop)) - file_watcher.start() - self._file_watcher_processes[dataset_id] = (file_watcher, should_stop, 0) - - def _stop_file_watcher_process(self, dataset_id: int) -> None: - """Stop a file watcher. - - Args: - dataset_id (int): ID of the dataset that should be watched. 
- """ - self._file_watcher_processes[dataset_id][1].value = True - i = 0 - while self._file_watcher_processes[dataset_id][0].is_alive() and i < 10: # Wait for the file watcher to stop. - time.sleep(1) - i += 1 - if self._file_watcher_processes[dataset_id][0].is_alive(): - logger.debug(f"File watcher for dataset {dataset_id} is still alive. Terminating it.") - self._file_watcher_processes[dataset_id][0].terminate() - self._file_watcher_processes[dataset_id][0].join() - del self._file_watcher_processes[dataset_id] - - def run(self) -> None: - """Run the new file watcher watchdog. - - Args: - modyn_config (dict): Configuration of the modyn module. - should_stop (Value): Value that indicates if the watcher should stop. - """ - while not self.__should_stop.value: - self._watch_file_watcher_processes() - time.sleep(3) - - for dataset_id in self._file_watcher_processes: - self._stop_file_watcher_process(dataset_id) - - -def run_watcher_watch_dog(modyn_config: dict, should_stop: Any): # type: ignore # See https://github.com/python/typeshed/issues/8799 # noqa: E501 - """Run the new file watcher watch dog. - - Args: - modyn_config (dict): Configuration of the modyn module. - should_stop (Value): Value that indicates if the watcher should stop. - """ - new_file_watcher_watch_dog = NewFileWatcherWatchDog(modyn_config, should_stop) - new_file_watcher_watch_dog.run() diff --git a/modyn/storage/internal/file_wrapper/__init__.py b/modyn/storage/internal/file_wrapper/__init__.py deleted file mode 100644 index a42b88bbd..000000000 --- a/modyn/storage/internal/file_wrapper/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""This module contains the file wrapper classes for the internal storage module. - -The file wrapper classes are used to abstract the file operations. -This allows the storage module to be used with different file formats. -""" -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/file_wrapper/abstract_file_wrapper.py b/modyn/storage/internal/file_wrapper/abstract_file_wrapper.py deleted file mode 100644 index 7795f5990..000000000 --- a/modyn/storage/internal/file_wrapper/abstract_file_wrapper.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Base class for all file wrappers.""" - -from abc import ABC, abstractmethod -from typing import Optional - -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper - - -class AbstractFileWrapper(ABC): - """Base class for all file wrappers.""" - - def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper: AbstractFileSystemWrapper): - """Init file wrapper. - - Args: - file_path (str): Path to file - file_wrapper_config (dict): File wrapper config - """ - self.file_wrapper_type: FileWrapperType = None - self.file_path = file_path - self.file_wrapper_config = file_wrapper_config - self.filesystem_wrapper = filesystem_wrapper - - @abstractmethod - def get_number_of_samples(self) -> int: - """Get the size of the file in number of samples. - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - int: Number of samples - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_samples(self, start: int, end: int) -> list[bytes]: - """Get the samples from the file. 
- - Args: - start (int): Start index - end (int): End index - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bytes: Samples - """ - raise NotImplementedError # pragma: no cover - - def get_label(self, index: int) -> Optional[int]: - """Get the label at the given index. - - Args: - index (int): Index - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - int: Label if exists, else None - """ - raise NotImplementedError # pragma: no cover - - def get_all_labels(self) -> list[Optional[int]]: - """Returns a list of all labels of all samples in the file. - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - list[Optional[int]]: List of labels - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_sample(self, index: int) -> bytes: - """Get the sample at the given index. - - Args: - index (int): Index - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bytes: Sample - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_samples_from_indices(self, indices: list) -> list[bytes]: - """Get the samples at the given indices. - - Args: - indices (list): List of indices - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bytes: Samples - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def delete_samples(self, indices: list) -> None: - """Delete the samples at the given indices. - - Args: - indices (list): List of indices - - Raises: - NotImplementedError: If the method is not implemented - """ - raise NotImplementedError diff --git a/modyn/storage/internal/file_wrapper/binary_file_wrapper.py b/modyn/storage/internal/file_wrapper/binary_file_wrapper.py deleted file mode 100644 index e5ceb0b0a..000000000 --- a/modyn/storage/internal/file_wrapper/binary_file_wrapper.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Binary file wrapper.""" - -from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper - - -class BinaryFileWrapper(AbstractFileWrapper): - """Binary file wrapper. - - Binary files store raw sample data in a row-oriented format. One file can contain multiple samples. - This wrapper requires that each samples should start with the label followed by its set of features. - Each sample should also have a fixed overall width (in bytes) and a fixed width for the label, - both of which should be provided in the config. The file wrapper is able to read samples by - offsetting the required number of bytes. - """ - - def __init__( - self, - file_path: str, - file_wrapper_config: dict, - filesystem_wrapper: AbstractFileSystemWrapper, - ): - """Init binary file wrapper. 
- - Args: - file_path (str): Path to file - file_wrapper_config (dict): File wrapper config - filesystem_wrapper (AbstractFileSystemWrapper): File system wrapper to abstract storage of the file - - Raises: - ValueError: If the file has the wrong file extension - ValueError: If the file does not contain an exact number of samples of given size - """ - super().__init__(file_path, file_wrapper_config, filesystem_wrapper) - self.file_wrapper_type = FileWrapperType.BinaryFileWrapper - self.byteorder = file_wrapper_config["byteorder"] - - self.record_size = file_wrapper_config["record_size"] - self.label_size = file_wrapper_config["label_size"] - if self.record_size - self.label_size < 1: - raise ValueError("Each record must have at least 1 byte of data other than the label.") - - self._validate_file_extension() - self.file_size = self.filesystem_wrapper.get_size(self.file_path) - if self.file_size % self.record_size != 0: - raise ValueError("File does not contain exact number of records of size " + str(self.record_size)) - - def _validate_file_extension(self) -> None: - """Validates the file extension as bin - - Raises: - ValueError: File has wrong file extension - """ - if not self.file_path.endswith(".bin"): - raise ValueError("File has wrong file extension.") - - def _validate_request_indices(self, total_samples: int, indices: list) -> None: - """Validates if the requested indices are in the range of total number of samples - in the file - - Args: - total_samples: Total number of samples in the file - indices (list): List of indices of the required samples - - Raises: - IndexError: If the index is out of bounds - """ - invalid_indices = any((idx < 0 or idx > (total_samples - 1)) for idx in indices) - if invalid_indices: - raise IndexError("Indices are out of range. Indices should be between 0 and " + str(total_samples)) - - def get_number_of_samples(self) -> int: - """Get number of samples in file. - - Returns: - int: Number of samples in file - """ - return int(self.file_size / self.record_size) - - def get_label(self, index: int) -> int: - """Get the label of the sample at the given index. - - Args: - index (int): Index - - Raises: - IndexError: If the index is out of bounds - - Returns: - int: Label for the sample - """ - data = self.filesystem_wrapper.get(self.file_path) - - total_samples = self.get_number_of_samples() - self._validate_request_indices(total_samples, [index]) - - record_start = index * self.record_size - lable_bytes = data[record_start : record_start + self.label_size] - return int.from_bytes(lable_bytes, byteorder=self.byteorder) - - def get_all_labels(self) -> list[int]: - """Returns a list of all labels of all samples in the file. - - Returns: - list[int]: List of labels - """ - data = self.filesystem_wrapper.get(self.file_path) - num_samples = self.get_number_of_samples() - labels = [ - int.from_bytes( - data[(idx * self.record_size) : (idx * self.record_size) + self.label_size], byteorder=self.byteorder - ) - for idx in range(num_samples) - ] - return labels - - def get_sample(self, index: int) -> bytes: - """Get the sample at the given index. - The indices are zero based. - - Args: - index (int): Index - - Raises: - IndexError: If the index is out of bounds - - Returns: - bytes: Sample - """ - return self.get_samples_from_indices([index])[0] - - def get_samples(self, start: int, end: int) -> list[bytes]: - """Get the samples at the given range from start (inclusive) to end (exclusive). - The indices are zero based. 
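The BinaryFileWrapper removed in this diff locates samples purely by byte offsets: each record is record_size bytes, the first label_size bytes hold the label, and the remainder is the payload. A minimal sketch of that arithmetic, with made-up values (not the Modyn API):

```python
# Two 4-byte records: a 2-byte big-endian label followed by 2 payload bytes each.
data = (1).to_bytes(2, "big") + b"ab" + (2).to_bytes(2, "big") + b"cd"
record_size, label_size, byteorder = 4, 2, "big"

assert len(data) % record_size == 0            # the wrapper rejects files that do not divide evenly
for idx in range(len(data) // record_size):
    start = idx * record_size
    label = int.from_bytes(data[start : start + label_size], byteorder=byteorder)
    sample = data[start + label_size : start + record_size]
    print(idx, label, sample)                  # 0 1 b'ab'  /  1 2 b'cd'
```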
- - Args: - start (int): Start index - end (int): End index - - Raises: - IndexError: If the index is out of bounds - - Returns: - bytes: Sample - """ - return self.get_samples_from_indices(list(range(start, end))) - - def get_samples_from_indices(self, indices: list) -> list[bytes]: - """Get the samples at the given index list. - The indices are zero based. - - Args: - indices (list): List of indices of the required samples - - Raises: - IndexError: If the index is out of bounds - - Returns: - bytes: Sample - """ - data = self.filesystem_wrapper.get(self.file_path) - - total_samples = len(data) / self.record_size - self._validate_request_indices(total_samples, indices) - - samples = [data[(idx * self.record_size) + self.label_size : (idx + 1) * self.record_size] for idx in indices] - return samples - - def delete_samples(self, indices: list) -> None: - """Delete the samples at the given index list. - The indices are zero based. - - We do not support deleting samples from binary files. - We can only delete the entire file which is done when every sample is deleted. - This is done to avoid the overhead of updating the file after every deletion. - - See remove_empty_files in the storage grpc servicer for more details. - - Args: - indices (list): List of indices of the samples to delete - """ - return diff --git a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py b/modyn/storage/internal/file_wrapper/csv_file_wrapper.py deleted file mode 100644 index 355fb5918..000000000 --- a/modyn/storage/internal/file_wrapper/csv_file_wrapper.py +++ /dev/null @@ -1,193 +0,0 @@ -import csv -from typing import Iterator, Optional - -from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper - - -class CsvFileWrapper(AbstractFileWrapper): - def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper: AbstractFileSystemWrapper): - super().__init__(file_path, file_wrapper_config, filesystem_wrapper) - - self.file_wrapper_type = FileWrapperType.CsvFileWrapper - - if "separator" in file_wrapper_config: - self.separator = file_wrapper_config["separator"] - else: - self.separator = "," - - if "label_index" not in file_wrapper_config: - raise ValueError("Please specify the index of the column that contains the label. ") - if not isinstance(file_wrapper_config["label_index"], int) or file_wrapper_config["label_index"] < 0: - raise ValueError("The label_index must be a positive integer.") - self.label_index = file_wrapper_config["label_index"] - - # the first line might contain the header, which is useless and must not be returned. 
- if "ignore_first_line" in file_wrapper_config: - self.ignore_first_line = file_wrapper_config["ignore_first_line"] - else: - self.ignore_first_line = False - - if "encoding" in file_wrapper_config: - self.encoding = file_wrapper_config["encoding"] - else: - self.encoding = "utf-8" - - # check that the file is actually a CSV - self._validate_file_extension() - - # do not validate the content only if "validate_file_content" is explicitly set to False - if ("validate_file_content" not in file_wrapper_config) or ( - "validate_file_content" in file_wrapper_config and file_wrapper_config["validate_file_content"] - ): - self._validate_file_content() - - def _validate_file_extension(self) -> None: - """Validates the file extension as csv - - Raises: - ValueError: File has wrong file extension - """ - if not self.file_path.endswith(".csv"): - raise ValueError("File has wrong file extension.") - - def _validate_file_content(self) -> None: - """ - Performs the following checks: - - specified label column is castable to integer - - each row has the label_index_column - - each row has the same width - - Raises a ValueError if a condition is not met - """ - - reader = self._get_csv_reader() - - number_of_columns = [] - - for row in reader: - number_of_columns.append(len(row)) - if not 0 <= self.label_index < len(row): - raise ValueError("Label index outside row boundary") - if not row[self.label_index].isnumeric(): # returns true iff all the characters are numbers - raise ValueError("The label must be an integer") - - if len(set(number_of_columns)) != 1: - raise ValueError( - "Some rows have different width. " f"This is the number of columns row by row {number_of_columns}" - ) - - def get_sample(self, index: int) -> bytes: - samples = self._filter_rows_samples([index]) - - if len(samples) != 1: - raise IndexError("Invalid index") - - return samples[0] - - def get_samples(self, start: int, end: int) -> list[bytes]: - indices = list(range(start, end)) - return self.get_samples_from_indices(indices) - - def get_samples_from_indices(self, indices: list) -> list[bytes]: - return self._filter_rows_samples(indices) - - def get_label(self, index: int) -> int: - labels = self._filter_rows_labels([index]) - - if len(labels) != 1: - raise IndexError("Invalid index.") - - return labels[0] - - def get_all_labels(self) -> list[int]: - reader = self._get_csv_reader() - labels = [int(row[self.label_index]) for row in reader] - return labels - - def get_number_of_samples(self) -> int: - reader = self._get_csv_reader() - return sum(1 for _ in reader) - - def _get_csv_reader(self) -> Iterator: - """ - Receives the bytes from the file_system_wrapper and creates a csv.reader out of it. - Returns: - csv.reader - """ - data_file = self.filesystem_wrapper.get(self.file_path) - - # Convert bytes content to a string - data_file_str = data_file.decode(self.encoding) - - lines = data_file_str.split("\n") - - # Create a CSV reader - reader = csv.reader(lines, delimiter=self.separator) - - # skip the header if required - if self.ignore_first_line: - next(reader) - - return reader - - def _filter_rows_samples(self, indices: list[int]) -> list[bytes]: - """ - Filters the selected rows and removes the label column - Args: - indices: list of rows that must be kept - - Returns: - list of byte-encoded rows - - """ - assert len(indices) == len(set(indices)), "An index is required more than once." 
- reader = self._get_csv_reader() - - # Iterate over the rows and keep the selected ones - filtered_rows: list[Optional[bytes]] = [None] * len(indices) - for i, row in enumerate(reader): - if i in indices: - # Remove the label, convert the row to bytes and append to the list - row_without_label = [col for j, col in enumerate(row) if j != self.label_index] - # the row is transformed in a similar csv using the same separator and then transformed to bytes - filtered_rows[indices.index(i)] = bytes(self.separator.join(row_without_label), self.encoding) - - if sum(1 for el in filtered_rows if el is None) != 0: - raise IndexError("At least one index is invalid") - - # Here mypy complains that filtered_rows is a list of list[Optional[bytes]], - # that can't happen given the above exception - return filtered_rows # type: ignore - - def _filter_rows_labels(self, indices: list[int]) -> list[int]: - """ - Filters the selected rows and extracts the label column - Args: - indices: list of rows that must be kept - - Returns: - list of labels - - """ - assert len(indices) == len(set(indices)), "An index is required more than once." - reader = self._get_csv_reader() - - # Iterate over the rows and keep the selected ones - filtered_rows: list[Optional[int]] = [None] * len(indices) - for i, row in enumerate(reader): - if i in indices: - # labels are integer in modyn - int_label = int(row[self.label_index]) - filtered_rows[indices.index(i)] = int_label - - if sum(1 for el in filtered_rows if el is None) != 0: - raise IndexError("At least one index is invalid") - - # Here mypy complains that filtered_rows is a list of list[Optional[bytes]], - # that can't happen given the above exception - return filtered_rows # type: ignore - - def delete_samples(self, indices: list) -> None: - pass diff --git a/modyn/storage/internal/file_wrapper/file_wrapper_type.py b/modyn/storage/internal/file_wrapper/file_wrapper_type.py deleted file mode 100644 index 758b99cbe..000000000 --- a/modyn/storage/internal/file_wrapper/file_wrapper_type.py +++ /dev/null @@ -1,27 +0,0 @@ -"""File wrapper type enum and exception.""" - -from enum import Enum - - -class FileWrapperType(Enum): - """Enum for the type of file wrapper. - - Important: The value of the enum must be the same as the name of the module. - The name of the enum must be the same as the name of the class. - """ - - SingleSampleFileWrapper = "single_sample_file_wrapper" # pylint: disable=invalid-name - BinaryFileWrapper = "binary_file_wrapper" # pylint: disable=invalid-name - CsvFileWrapper = "csv_file_wrapper" # pylint: disable=invalid-name - - -class InvalidFileWrapperTypeException(Exception): - """Invalid file wrapper type exception.""" - - def __init__(self, message: str): - """Init exception. 
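To illustrate what the CsvFileWrapper deleted above computes: it decodes the raw file bytes, feeds the lines to csv.reader with the configured separator, treats one column as the integer label, and re-serializes the remaining columns as the sample. A small sketch over hypothetical data (not the Modyn API):

```python
import csv

payload = b"f1;f2;3\na;b;7\n"                  # hypothetical file content
separator, label_index, encoding = ";", 2, "utf-8"

reader = csv.reader(payload.decode(encoding).split("\n"), delimiter=separator)
for row in reader:
    if not row:                                # split() leaves a trailing empty line
        continue
    label = int(row[label_index])
    sample = separator.join(c for j, c in enumerate(row) if j != label_index).encode(encoding)
    print(label, sample)                       # 3 b'f1;f2'  /  7 b'a;b'
```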
- - Args: - message (str): Exception message - """ - super().__init__(message) diff --git a/modyn/storage/internal/file_wrapper/single_sample_file_wrapper.py b/modyn/storage/internal/file_wrapper/single_sample_file_wrapper.py deleted file mode 100644 index b605f93a3..000000000 --- a/modyn/storage/internal/file_wrapper/single_sample_file_wrapper.py +++ /dev/null @@ -1,136 +0,0 @@ -"""A file wrapper for files that contains only one sample and metadata.""" - -import logging -import pathlib -from typing import Optional - -from modyn.storage.internal.file_wrapper.abstract_file_wrapper import AbstractFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper - -logger = logging.getLogger(__name__) - - -class SingleSampleFileWrapper(AbstractFileWrapper): - """A file wrapper for files that contains only one sample and metadata. - - For example, a file that contains only one image and metadata. - The metadata is stored in a json file with the same name as the image file. - """ - - def __init__(self, file_path: str, file_wrapper_config: dict, filesystem_wrapper: AbstractFileSystemWrapper): - """Init file wrapper. - - Args: - file_path (str): File path - file_wrapper_config (dict): File wrapper config - filesystem_wrapper (AbstractFileSystemWrapper): File system wrapper to abstract storage of the file - """ - super().__init__(file_path, file_wrapper_config, filesystem_wrapper) - self.file_wrapper_type = FileWrapperType.SingleSampleFileWrapper - - def get_number_of_samples(self) -> int: - """Get the size of the file in number of samples. - - If the file has the correct file extension, it contains only one sample. - - Returns: - int: Number of samples - """ - if not self.file_path.endswith(self.file_wrapper_config["file_extension"]): - return 0 - return 1 - - def get_samples(self, start: int, end: int) -> list[bytes]: - """Get the samples from the file. - - Args: - start (int): start index - end (int): end index - - Raises: - IndexError: If the start and end index are not 0 and 1 - - Returns: - bytes: Samples - """ - if start != 0 or end != 1: - raise IndexError("SingleSampleFileWrapper contains only one sample.") - return [self.get_sample(0)] - - def get_sample(self, index: int) -> bytes: - r"""Return the sample as bytes. - - Args: - index (int): Index - - Raises: - ValueError: If the file has the wrong file extension - IndexError: If the index is not 0 - - Returns: - bytes: Sample - """ - if self.get_number_of_samples() == 0: - raise ValueError("File has wrong file extension.") - if index != 0: - raise IndexError("SingleSampleFileWrapper contains only one sample.") - data_file = self.filesystem_wrapper.get(self.file_path) - return data_file - - def get_label(self, index: int) -> Optional[int]: - """Get the label of the sample at the given index. 
- - Args: - index (int): Index - - Raises: - ValueError: If the file has the wrong file extension - IndexError: If the index is not 0 - - Returns: - int: Label if exists, else None - """ - if self.get_number_of_samples() == 0: - raise ValueError("File has wrong file extension.") - if index != 0: - raise IndexError("SingleSampleFileWrapper contains only one sample.") - if ( - "label_file_extension" not in self.file_wrapper_config - or self.file_wrapper_config["label_file_extension"] is None - ): - logger.warning("No label file extension defined.") - return None - label_path = pathlib.Path(self.file_path).with_suffix(self.file_wrapper_config["label_file_extension"]) - label = self.filesystem_wrapper.get(label_path) - if label is not None: - label = label.decode("utf-8") - return int(label) - return None - - def get_all_labels(self) -> list[Optional[int]]: - """Returns a list of all labels of all samples in the file. - - Returns: - list[Optional[int]]: List of labels - """ - return [self.get_label(0)] - - def get_samples_from_indices(self, indices: list) -> list[bytes]: - """Get the samples from the file. - - Args: - indices (list): Indices - - Raises: - IndexError: If the indices are not valid - - Returns: - bytes: Samples - """ - if len(indices) != 1 or indices[0] != 0: - raise IndexError("SingleSampleFileWrapper contains only one sample.") - return [self.get_sample(0)] - - def delete_samples(self, indices: list) -> None: - self.filesystem_wrapper.delete(self.file_path) diff --git a/modyn/storage/internal/filesystem_wrapper/__init__.py b/modyn/storage/internal/filesystem_wrapper/__init__.py deleted file mode 100644 index c6005a336..000000000 --- a/modyn/storage/internal/filesystem_wrapper/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""This package contains the file system wrapper classes. - -The file system wrapper classes are used to abstract the file system -operations. This allows the storage module to be used with different file systems. -""" -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/filesystem_wrapper/abstract_filesystem_wrapper.py b/modyn/storage/internal/filesystem_wrapper/abstract_filesystem_wrapper.py deleted file mode 100644 index 5b759127f..000000000 --- a/modyn/storage/internal/filesystem_wrapper/abstract_filesystem_wrapper.py +++ /dev/null @@ -1,177 +0,0 @@ -"""Abstract filesystem wrapper class.""" - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Union - -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType - - -class AbstractFileSystemWrapper(ABC): - """Base class for all filesystem wrappers.""" - - filesystem_wrapper_type: FilesystemWrapperType = None - - def __init__(self, base_path: str): - """Init filesystem wrapper. - - Args: - base_path (str): Base path of filesystem - """ - self.base_path = base_path - - def get(self, path: Union[str, Path]) -> bytes: - """Get file content. - - Args: - path (Union[str, Path]): Absolute path to file - - Returns: - bytes: File content - """ - return self._get(str(path)) - - @abstractmethod - def _get(self, path: str) -> bytes: - """Get file content. 
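The SingleSampleFileWrapper above resolves labels through a sidecar file that shares the sample's stem but uses the configured label_file_extension. A sketch of that lookup with hypothetical file names and extension:

```python
import pathlib
import tempfile

tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "img_1.png").write_bytes(b"fake image bytes")   # hypothetical sample
(tmp / "img_1.label").write_text("42")                 # hypothetical label sidecar

sample_path = tmp / "img_1.png"
label_path = sample_path.with_suffix(".label")         # label_file_extension = ".label"
label = int(label_path.read_text()) if label_path.exists() else None
print(label)                                           # 42
```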
- - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bytes: File content - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def exists(self, path: str) -> bool: - """Exists checks whether the given path exists or not. - - Args: - path (str): Absolute path to file or directory - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bool: True if path exists, False otherwise - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def list(self, path: str, recursive: bool = False) -> list[str]: - """List files in directory. - - Args: - path (str): Absolute path to directory - recursive (bool, optional): Recursively list files. Defaults to False. - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - list[str]: List of files - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def isdir(self, path: str) -> bool: - """Return `True` if the path is a directory. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bool: True if path is a directory, False otherwise - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def isfile(self, path: str) -> bool: - """Return `True` if the path is a file. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - bool: True if path is a file, False otherwise - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_size(self, path: str) -> int: - """Return the size of the file. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - int: Size of file - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_modified(self, path: str) -> int: - """Return the last modified time of the file. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - int: Last modified time - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_created(self, path: str) -> int: - """Return the creation time of the file. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - - Returns: - int: Creation time - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def join(self, *paths: str) -> str: - """Join paths. - - Raises: - NotImplementedError: If not implemented - - Returns: - str: Joined path - """ - raise NotImplementedError # pragma: no cover - - @abstractmethod - def delete(self, path: str) -> None: - """Delete file. - - Args: - path (str): Absolute path to file - - Raises: - NotImplementedError: If the method is not implemented - """ - raise NotImplementedError diff --git a/modyn/storage/internal/filesystem_wrapper/filesystem_wrapper_type.py b/modyn/storage/internal/filesystem_wrapper/filesystem_wrapper_type.py deleted file mode 100644 index 7213e2c1d..000000000 --- a/modyn/storage/internal/filesystem_wrapper/filesystem_wrapper_type.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Filesystem wrapper type and exception.""" -from enum import Enum - - -class FilesystemWrapperType(Enum): - """Enum for the type of file system wrapper. - - Important: The value of the enum must be the same as the name of the module. 
- The name of the enum must be the same as the name of the class. - """ - - LocalFilesystemWrapper = "local_filesystem_wrapper" # pylint: disable=invalid-name - - -class InvalidFilesystemWrapperTypeException(Exception): - """Exception for invalid filesystem wrapper type.""" - - def __init__(self, message: str): - """Init exception. - - Args: - message (str): Exception message - """ - super().__init__(message) diff --git a/modyn/storage/internal/filesystem_wrapper/local_filesystem_wrapper.py b/modyn/storage/internal/filesystem_wrapper/local_filesystem_wrapper.py deleted file mode 100644 index 4e6d6a818..000000000 --- a/modyn/storage/internal/filesystem_wrapper/local_filesystem_wrapper.py +++ /dev/null @@ -1,177 +0,0 @@ -"""Local filesystem wrapper. - -This module contains the local filesystem wrapper. -It is used to access files on the local filesystem. -""" -import os - -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType - - -class LocalFilesystemWrapper(AbstractFileSystemWrapper): - """Local filesystem wrapper.""" - - def __init__(self, base_path: str): - """Init local filesystem wrapper. - - Args: - base_path (str): Base path of local filesystem - """ - super().__init__(base_path) - self.filesystem_wrapper_type = FilesystemWrapperType.LocalFilesystemWrapper - - def __is_valid_path(self, path: str) -> bool: - return path.startswith(self.base_path) - - def _get(self, path: str) -> bytes: - """Get file content. - - Args: - path (str): Absolute path to file - - Raises: - FileNotFoundError: If path is not valid - IsADirectoryError: If path is a directory - - Returns: - bytes: File content - """ - if not self.__is_valid_path(path): - raise ValueError(f"Path {path} is not valid.") - if not self.isfile(path): - raise IsADirectoryError(f"Path {path} is a directory.") - with open(path, "rb") as file: - return file.read() - - def exists(self, path: str) -> bool: - """Check if path exists. - - Args: - path (str): Absolute path to file or directory - - Returns: - bool: True if path exists, False otherwise - """ - return os.path.exists(path) - - def list(self, path: str, recursive: bool = False) -> list[str]: - """List files in directory. - - Args: - path (str): Absolute path to directory - recursive (bool, optional): List files recursively. Defaults to False. - - Raises: - ValueError: If path is not valid - NotADirectoryError: If path is not a directory - - Returns: - list[str]: List of files in directory - """ - if not self.__is_valid_path(path): - raise ValueError(f"Path {path} is not valid.") - if not self.isdir(path): - raise NotADirectoryError(f"Path {path} is not a directory.") - if recursive: - return [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(path)) for f in fn] - return os.listdir(path) - - def isdir(self, path: str) -> bool: - """Check if path is a directory. - - Args: - path (str): Absolute path to directory - - Returns: - bool: True if path is a directory, False otherwise - """ - return os.path.isdir(path) - - def isfile(self, path: str) -> bool: - """Check if path is a file. - - Args: - path (str): Absolute path to file - - Returns: - bool: True if path is a file, False otherwise - """ - return os.path.isfile(path) - - def get_size(self, path: str) -> int: - """Get size of file. 
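Two conventions of the LocalFilesystemWrapper being removed are easy to miss: access is confined to paths under base_path via a simple prefix check, and the time getters that follow report integer milliseconds rather than float seconds. A compact, illustrative sketch (base_path is made up):

```python
import os

base_path = "/tmp/modyn_data"                          # hypothetical dataset root


def is_valid(path: str) -> bool:
    # Same prefix check the wrapper uses to keep access inside base_path.
    return path.startswith(base_path)


def modified_ms(path: str) -> int:
    # os.path.getmtime() returns float seconds; the wrapper scales to truncated milliseconds.
    return int(os.path.getmtime(path) * 1000)


print(is_valid("/tmp/modyn_data/part-0.bin"))          # True
print(is_valid("/etc/passwd"))                         # False
```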
- - Args: - path (str): Absolute path to file - - Raises: - ValueError: If path is not valid - IsADirectoryError: If path is a directory - - Returns: - int: Size of file in bytes - """ - if not self.__is_valid_path(path): - raise ValueError(f"Path {path} is not valid.") - if not self.isfile(path): - raise IsADirectoryError(f"Path {path} is a directory.") - return os.path.getsize(path) - - def get_modified(self, path: str) -> int: - """Get modification time of file. - - Args: - path (str): Absolute path to file - - Raises: - ValueError: If path is not valid - IsADirectoryError: If path is a directory - - Returns: - int: Modification time in milliseconds rounded to the nearest integer - """ - if not self.__is_valid_path(path): - raise ValueError(f"Path {path} is not valid.") - if not self.isfile(path): - raise IsADirectoryError(f"Path {path} is a directory.") - return int(os.path.getmtime(path) * 1000) - - def get_created(self, path: str) -> int: - """Get creation time of file. - - Args: - path (str): Absolute path to file - - Raises: - ValueError: If path is not valid - IsADirectoryError: If path is a directory - - Returns: - int: Creation time in milliseconds rounded to the nearest integer - """ - if not self.__is_valid_path(path): - raise ValueError(f"Path {path} is not valid.") - if not self.isfile(path): - raise IsADirectoryError(f"Path {path} is a directory.") - return int(os.path.getctime(path) * 1000) - - def join(self, *paths: str) -> str: - """Join paths. - - Returns: - str: Joined path - """ - return os.path.join(*paths) - - def delete(self, path: str) -> None: - """Delete file. - - Args: - path (str): Absolute path to file - - Raises: - ValueError: If path is not valid - IsADirectoryError: If path is a directory - """ - return os.remove(path) diff --git a/modyn/storage/internal/grpc/__init__.py b/modyn/storage/internal/grpc/__init__.py deleted file mode 100644 index 4e54d865f..000000000 --- a/modyn/storage/internal/grpc/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Storage module. - -The storage module contains all classes and functions related to the storage and retrieval of data. -""" - -import os - -files = os.listdir(os.path.dirname(__file__)) -files.remove("__init__.py") -__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/storage/internal/grpc/generated/__init__.py b/modyn/storage/internal/grpc/generated/__init__.py index 4e54d865f..982984594 100644 --- a/modyn/storage/internal/grpc/generated/__init__.py +++ b/modyn/storage/internal/grpc/generated/__init__.py @@ -1,6 +1,7 @@ -"""Storage module. +""" +Storage module. -The storage module contains all classes and functions related to the storage and retrieval of data. +The storage module contains all classes and functions related to the storage and retrieval of data.
""" import os diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.py b/modyn/storage/internal/grpc/generated/storage_pb2.py index 7b6de42fc..3908182d6 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2.py @@ -14,49 +14,53 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rstorage.proto\x12\rmodyn.storage\x1a\x1bgoogle/protobuf/empty.proto\".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03\"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03\"W\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\"+\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03\"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03\"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03\"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03\"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x32\xcf\x07\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse\"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse\"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse\"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse\"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse\"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse\"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse\"\x00\x12[\n\x13GetCurrentTimestamp\x12\x16.google.protobuf.Empty\x1a*.modyn.storage.GetCurrentTimestampResponse\"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse\"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\rstorage.proto\x12\rmodyn.storage".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\x1c\n\x1aGetCurrentTimestampRequest"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"W\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03"+\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x32\xe2\x07\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse"\x00\x62\x06proto3' +) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'storage_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "storage_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _GETREQUEST._serialized_start=61 - _GETREQUEST._serialized_end=107 - _GETRESPONSE._serialized_start=109 - _GETRESPONSE._serialized_end=169 - _GETNEWDATASINCEREQUEST._serialized_start=171 - _GETNEWDATASINCEREQUEST._serialized_end=234 - _GETNEWDATASINCERESPONSE._serialized_start=236 - _GETNEWDATASINCERESPONSE._serialized_end=311 - _GETDATAININTERVALREQUEST._serialized_start=313 - _GETDATAININTERVALREQUEST._serialized_end=407 - _GETDATAININTERVALRESPONSE._serialized_start=409 - _GETDATAININTERVALRESPONSE._serialized_end=486 - _GETDATAPERWORKERREQUEST._serialized_start=488 - _GETDATAPERWORKERREQUEST._serialized_end=575 - _GETDATAPERWORKERRESPONSE._serialized_start=577 - _GETDATAPERWORKERRESPONSE._serialized_end=617 - _GETDATASETSIZEREQUEST._serialized_start=619 - _GETDATASETSIZEREQUEST._serialized_end=662 - _GETDATASETSIZERESPONSE._serialized_start=664 - _GETDATASETSIZERESPONSE._serialized_end=723 - _DATASETAVAILABLEREQUEST._serialized_start=725 - _DATASETAVAILABLEREQUEST._serialized_end=770 - _DATASETAVAILABLERESPONSE._serialized_start=772 - _DATASETAVAILABLERESPONSE._serialized_end=817 - _REGISTERNEWDATASETREQUEST._serialized_start=820 - _REGISTERNEWDATASETREQUEST._serialized_end=1075 - _REGISTERNEWDATASETRESPONSE._serialized_start=1077 - _REGISTERNEWDATASETRESPONSE._serialized_end=1122 - _GETCURRENTTIMESTAMPRESPONSE._serialized_start=1124 - _GETCURRENTTIMESTAMPRESPONSE._serialized_end=1172 - _DELETEDATASETRESPONSE._serialized_start=1174 - _DELETEDATASETRESPONSE._serialized_end=1214 - _DELETEDATAREQUEST._serialized_start=1216 - _DELETEDATAREQUEST._serialized_end=1269 - _DELETEDATARESPONSE._serialized_start=1271 - _DELETEDATARESPONSE._serialized_end=1308 - _STORAGE._serialized_start=1311 - _STORAGE._serialized_end=2286 + DESCRIPTOR._options = None + _globals["_GETREQUEST"]._serialized_start = 32 + _globals["_GETREQUEST"]._serialized_end = 78 + _globals["_GETRESPONSE"]._serialized_start = 80 + 
_globals["_GETRESPONSE"]._serialized_end = 140 + _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_start = 142 + _globals["_GETCURRENTTIMESTAMPREQUEST"]._serialized_end = 170 + _globals["_GETNEWDATASINCEREQUEST"]._serialized_start = 172 + _globals["_GETNEWDATASINCEREQUEST"]._serialized_end = 235 + _globals["_GETNEWDATASINCERESPONSE"]._serialized_start = 237 + _globals["_GETNEWDATASINCERESPONSE"]._serialized_end = 312 + _globals["_GETDATAININTERVALREQUEST"]._serialized_start = 314 + _globals["_GETDATAININTERVALREQUEST"]._serialized_end = 408 + _globals["_GETDATAININTERVALRESPONSE"]._serialized_start = 410 + _globals["_GETDATAININTERVALRESPONSE"]._serialized_end = 487 + _globals["_GETDATAPERWORKERREQUEST"]._serialized_start = 489 + _globals["_GETDATAPERWORKERREQUEST"]._serialized_end = 576 + _globals["_GETDATAPERWORKERRESPONSE"]._serialized_start = 578 + _globals["_GETDATAPERWORKERRESPONSE"]._serialized_end = 618 + _globals["_GETDATASETSIZEREQUEST"]._serialized_start = 620 + _globals["_GETDATASETSIZEREQUEST"]._serialized_end = 663 + _globals["_GETDATASETSIZERESPONSE"]._serialized_start = 665 + _globals["_GETDATASETSIZERESPONSE"]._serialized_end = 724 + _globals["_DATASETAVAILABLEREQUEST"]._serialized_start = 726 + _globals["_DATASETAVAILABLEREQUEST"]._serialized_end = 771 + _globals["_DATASETAVAILABLERESPONSE"]._serialized_start = 773 + _globals["_DATASETAVAILABLERESPONSE"]._serialized_end = 818 + _globals["_REGISTERNEWDATASETREQUEST"]._serialized_start = 821 + _globals["_REGISTERNEWDATASETREQUEST"]._serialized_end = 1076 + _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_start = 1078 + _globals["_REGISTERNEWDATASETRESPONSE"]._serialized_end = 1123 + _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_start = 1125 + _globals["_GETCURRENTTIMESTAMPRESPONSE"]._serialized_end = 1173 + _globals["_DELETEDATASETRESPONSE"]._serialized_start = 1175 + _globals["_DELETEDATASETRESPONSE"]._serialized_end = 1215 + _globals["_DELETEDATAREQUEST"]._serialized_start = 1217 + _globals["_DELETEDATAREQUEST"]._serialized_end = 1270 + _globals["_DELETEDATARESPONSE"]._serialized_start = 1272 + _globals["_DELETEDATARESPONSE"]._serialized_end = 1309 + _globals["_STORAGE"]._serialized_start = 1312 + _globals["_STORAGE"]._serialized_end = 2306 # @@protoc_insertion_point(module_scope) diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.pyi b/modyn/storage/internal/grpc/generated/storage_pb2.pyi index 6a6c293fe..5fa510ffc 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.pyi +++ b/modyn/storage/internal/grpc/generated/storage_pb2.pyi @@ -24,14 +24,23 @@ class GetRequest(google.protobuf.message.Message): KEYS_FIELD_NUMBER: builtins.int dataset_id: builtins.str @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dataset_id", b"dataset_id", "keys", b"keys" + ], + ) -> None: ... 
global___GetRequest = GetRequest @@ -43,11 +52,23 @@ class GetResponse(google.protobuf.message.Message): KEYS_FIELD_NUMBER: builtins.int LABELS_FIELD_NUMBER: builtins.int @property - def samples(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def samples( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.bytes + ]: ... @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... @property - def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def labels( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, @@ -55,10 +76,27 @@ class GetResponse(google.protobuf.message.Message): keys: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "keys", b"keys", "labels", b"labels", "samples", b"samples" + ], + ) -> None: ... global___GetResponse = GetResponse +@typing_extensions.final +class GetCurrentTimestampRequest(google.protobuf.message.Message): + """https://github.com/grpc/grpc/issues/15937""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___GetCurrentTimestampRequest = GetCurrentTimestampRequest + @typing_extensions.final class GetNewDataSinceRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -73,7 +111,12 @@ class GetNewDataSinceRequest(google.protobuf.message.Message): dataset_id: builtins.str = ..., timestamp: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dataset_id", b"dataset_id", "timestamp", b"timestamp" + ], + ) -> None: ... global___GetNewDataSinceRequest = GetNewDataSinceRequest @@ -85,11 +128,23 @@ class GetNewDataSinceResponse(google.protobuf.message.Message): TIMESTAMPS_FIELD_NUMBER: builtins.int LABELS_FIELD_NUMBER: builtins.int @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... @property - def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def timestamps( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... @property - def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def labels( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, @@ -97,7 +152,12 @@ class GetNewDataSinceResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... 
- def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "keys", b"keys", "labels", b"labels", "timestamps", b"timestamps" + ], + ) -> None: ... global___GetNewDataSinceResponse = GetNewDataSinceResponse @@ -118,7 +178,17 @@ class GetDataInIntervalRequest(google.protobuf.message.Message): start_timestamp: builtins.int = ..., end_timestamp: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dataset_id", + b"dataset_id", + "end_timestamp", + b"end_timestamp", + "start_timestamp", + b"start_timestamp", + ], + ) -> None: ... global___GetDataInIntervalRequest = GetDataInIntervalRequest @@ -130,11 +200,23 @@ class GetDataInIntervalResponse(google.protobuf.message.Message): TIMESTAMPS_FIELD_NUMBER: builtins.int LABELS_FIELD_NUMBER: builtins.int @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... @property - def timestamps(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def timestamps( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... @property - def labels(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def labels( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, @@ -142,7 +224,12 @@ class GetDataInIntervalResponse(google.protobuf.message.Message): timestamps: collections.abc.Iterable[builtins.int] | None = ..., labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "keys", b"keys", "labels", b"labels", "timestamps", b"timestamps" + ], + ) -> None: ... global___GetDataInIntervalResponse = GetDataInIntervalResponse @@ -163,7 +250,17 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): worker_id: builtins.int = ..., total_workers: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "total_workers", b"total_workers", "worker_id", b"worker_id"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dataset_id", + b"dataset_id", + "total_workers", + b"total_workers", + "worker_id", + b"worker_id", + ], + ) -> None: ... global___GetDataPerWorkerRequest = GetDataPerWorkerRequest @@ -173,13 +270,19 @@ class GetDataPerWorkerResponse(google.protobuf.message.Message): KEYS_FIELD_NUMBER: builtins.int @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... 
- def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["keys", b"keys"] + ) -> None: ... global___GetDataPerWorkerResponse = GetDataPerWorkerResponse @@ -194,7 +297,9 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): *, dataset_id: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id"] + ) -> None: ... global___GetDatasetSizeRequest = GetDatasetSizeRequest @@ -212,7 +317,12 @@ class GetDatasetSizeResponse(google.protobuf.message.Message): success: builtins.bool = ..., num_keys: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["num_keys", b"num_keys", "success", b"success"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "num_keys", b"num_keys", "success", b"success" + ], + ) -> None: ... global___GetDatasetSizeResponse = GetDatasetSizeResponse @@ -227,7 +337,9 @@ class DatasetAvailableRequest(google.protobuf.message.Message): *, dataset_id: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id"] + ) -> None: ... global___DatasetAvailableRequest = DatasetAvailableRequest @@ -242,7 +354,9 @@ class DatasetAvailableResponse(google.protobuf.message.Message): *, available: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["available", b"available"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["available", b"available"] + ) -> None: ... global___DatasetAvailableResponse = DatasetAvailableResponse @@ -281,7 +395,29 @@ class RegisterNewDatasetRequest(google.protobuf.message.Message): ignore_last_timestamp: builtins.bool = ..., file_watcher_interval: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["base_path", b"base_path", "dataset_id", b"dataset_id", "description", b"description", "file_watcher_interval", b"file_watcher_interval", "file_wrapper_config", b"file_wrapper_config", "file_wrapper_type", b"file_wrapper_type", "filesystem_wrapper_type", b"filesystem_wrapper_type", "ignore_last_timestamp", b"ignore_last_timestamp", "version", b"version"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "base_path", + b"base_path", + "dataset_id", + b"dataset_id", + "description", + b"description", + "file_watcher_interval", + b"file_watcher_interval", + "file_wrapper_config", + b"file_wrapper_config", + "file_wrapper_type", + b"file_wrapper_type", + "filesystem_wrapper_type", + b"filesystem_wrapper_type", + "ignore_last_timestamp", + b"ignore_last_timestamp", + "version", + b"version", + ], + ) -> None: ... global___RegisterNewDatasetRequest = RegisterNewDatasetRequest @@ -296,7 +432,9 @@ class RegisterNewDatasetResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["success", b"success"] + ) -> None: ... 
global___RegisterNewDatasetResponse = RegisterNewDatasetResponse @@ -311,7 +449,9 @@ class GetCurrentTimestampResponse(google.protobuf.message.Message): *, timestamp: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["timestamp", b"timestamp"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["timestamp", b"timestamp"] + ) -> None: ... global___GetCurrentTimestampResponse = GetCurrentTimestampResponse @@ -326,7 +466,9 @@ class DeleteDatasetResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["success", b"success"] + ) -> None: ... global___DeleteDatasetResponse = DeleteDatasetResponse @@ -338,14 +480,23 @@ class DeleteDataRequest(google.protobuf.message.Message): KEYS_FIELD_NUMBER: builtins.int dataset_id: builtins.str @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[ + builtins.int + ]: ... def __init__( self, *, dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dataset_id", b"dataset_id", "keys", b"keys" + ], + ) -> None: ... global___DeleteDataRequest = DeleteDataRequest @@ -360,6 +511,8 @@ class DeleteDataResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["success", b"success"] + ) -> None: ... global___DeleteDataResponse = DeleteDataResponse diff --git a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py index def2727e9..ec0e263f0 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py @@ -15,55 +15,55 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Get = channel.unary_stream( - '/modyn.storage.Storage/Get', - request_serializer=storage__pb2.GetRequest.SerializeToString, - response_deserializer=storage__pb2.GetResponse.FromString, - ) + "/modyn.storage.Storage/Get", + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + ) self.GetNewDataSince = channel.unary_stream( - '/modyn.storage.Storage/GetNewDataSince', - request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, - response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, - ) + "/modyn.storage.Storage/GetNewDataSince", + request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, + response_deserializer=storage__pb2.GetNewDataSinceResponse.FromString, + ) self.GetDataInInterval = channel.unary_stream( - '/modyn.storage.Storage/GetDataInInterval', - request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, - ) + "/modyn.storage.Storage/GetDataInInterval", + request_serializer=storage__pb2.GetDataInIntervalRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataInIntervalResponse.FromString, + ) self.GetDataPerWorker = channel.unary_stream( - '/modyn.storage.Storage/GetDataPerWorker', - request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, - response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, - ) + "/modyn.storage.Storage/GetDataPerWorker", + request_serializer=storage__pb2.GetDataPerWorkerRequest.SerializeToString, + response_deserializer=storage__pb2.GetDataPerWorkerResponse.FromString, + ) self.GetDatasetSize = channel.unary_unary( - '/modyn.storage.Storage/GetDatasetSize', - request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, - response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, - ) + "/modyn.storage.Storage/GetDatasetSize", + request_serializer=storage__pb2.GetDatasetSizeRequest.SerializeToString, + response_deserializer=storage__pb2.GetDatasetSizeResponse.FromString, + ) self.CheckAvailability = channel.unary_unary( - '/modyn.storage.Storage/CheckAvailability', - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, - ) + "/modyn.storage.Storage/CheckAvailability", + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DatasetAvailableResponse.FromString, + ) self.RegisterNewDataset = channel.unary_unary( - '/modyn.storage.Storage/RegisterNewDataset', - request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, - response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, - ) + "/modyn.storage.Storage/RegisterNewDataset", + request_serializer=storage__pb2.RegisterNewDatasetRequest.SerializeToString, + response_deserializer=storage__pb2.RegisterNewDatasetResponse.FromString, + ) self.GetCurrentTimestamp = channel.unary_unary( - '/modyn.storage.Storage/GetCurrentTimestamp', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, - ) + "/modyn.storage.Storage/GetCurrentTimestamp", + request_serializer=storage__pb2.GetCurrentTimestampRequest.SerializeToString, + response_deserializer=storage__pb2.GetCurrentTimestampResponse.FromString, + ) self.DeleteDataset = channel.unary_unary( - 
'/modyn.storage.Storage/DeleteDataset', - request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, - ) + "/modyn.storage.Storage/DeleteDataset", + request_serializer=storage__pb2.DatasetAvailableRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDatasetResponse.FromString, + ) self.DeleteData = channel.unary_unary( - '/modyn.storage.Storage/DeleteData', - request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, - response_deserializer=storage__pb2.DeleteDataResponse.FromString, - ) + "/modyn.storage.Storage/DeleteData", + request_serializer=storage__pb2.DeleteDataRequest.SerializeToString, + response_deserializer=storage__pb2.DeleteDataResponse.FromString, + ) class StorageServicer(object): @@ -72,292 +72,413 @@ class StorageServicer(object): def Get(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetNewDataSince(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetDataInInterval(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetDataPerWorker(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetDatasetSize(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def CheckAvailability(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def RegisterNewDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetCurrentTimestamp(self, request, context): """Missing associated documentation comment in .proto file.""" 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def DeleteDataset(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def DeleteData(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_StorageServicer_to_server(servicer, server): rpc_method_handlers = { - 'Get': grpc.unary_stream_rpc_method_handler( - servicer.Get, - request_deserializer=storage__pb2.GetRequest.FromString, - response_serializer=storage__pb2.GetResponse.SerializeToString, - ), - 'GetNewDataSince': grpc.unary_stream_rpc_method_handler( - servicer.GetNewDataSince, - request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, - response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, - ), - 'GetDataInInterval': grpc.unary_stream_rpc_method_handler( - servicer.GetDataInInterval, - request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, - response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, - ), - 'GetDataPerWorker': grpc.unary_stream_rpc_method_handler( - servicer.GetDataPerWorker, - request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, - response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, - ), - 'GetDatasetSize': grpc.unary_unary_rpc_method_handler( - servicer.GetDatasetSize, - request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, - response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, - ), - 'CheckAvailability': grpc.unary_unary_rpc_method_handler( - servicer.CheckAvailability, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, - ), - 'RegisterNewDataset': grpc.unary_unary_rpc_method_handler( - servicer.RegisterNewDataset, - request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, - response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, - ), - 'GetCurrentTimestamp': grpc.unary_unary_rpc_method_handler( - servicer.GetCurrentTimestamp, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, - ), - 'DeleteDataset': grpc.unary_unary_rpc_method_handler( - servicer.DeleteDataset, - request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, - response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, - ), - 'DeleteData': grpc.unary_unary_rpc_method_handler( - servicer.DeleteData, - request_deserializer=storage__pb2.DeleteDataRequest.FromString, - response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, - ), + "Get": grpc.unary_stream_rpc_method_handler( + servicer.Get, + request_deserializer=storage__pb2.GetRequest.FromString, + 
response_serializer=storage__pb2.GetResponse.SerializeToString, + ), + "GetNewDataSince": grpc.unary_stream_rpc_method_handler( + servicer.GetNewDataSince, + request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, + response_serializer=storage__pb2.GetNewDataSinceResponse.SerializeToString, + ), + "GetDataInInterval": grpc.unary_stream_rpc_method_handler( + servicer.GetDataInInterval, + request_deserializer=storage__pb2.GetDataInIntervalRequest.FromString, + response_serializer=storage__pb2.GetDataInIntervalResponse.SerializeToString, + ), + "GetDataPerWorker": grpc.unary_stream_rpc_method_handler( + servicer.GetDataPerWorker, + request_deserializer=storage__pb2.GetDataPerWorkerRequest.FromString, + response_serializer=storage__pb2.GetDataPerWorkerResponse.SerializeToString, + ), + "GetDatasetSize": grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetSize, + request_deserializer=storage__pb2.GetDatasetSizeRequest.FromString, + response_serializer=storage__pb2.GetDatasetSizeResponse.SerializeToString, + ), + "CheckAvailability": grpc.unary_unary_rpc_method_handler( + servicer.CheckAvailability, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DatasetAvailableResponse.SerializeToString, + ), + "RegisterNewDataset": grpc.unary_unary_rpc_method_handler( + servicer.RegisterNewDataset, + request_deserializer=storage__pb2.RegisterNewDatasetRequest.FromString, + response_serializer=storage__pb2.RegisterNewDatasetResponse.SerializeToString, + ), + "GetCurrentTimestamp": grpc.unary_unary_rpc_method_handler( + servicer.GetCurrentTimestamp, + request_deserializer=storage__pb2.GetCurrentTimestampRequest.FromString, + response_serializer=storage__pb2.GetCurrentTimestampResponse.SerializeToString, + ), + "DeleteDataset": grpc.unary_unary_rpc_method_handler( + servicer.DeleteDataset, + request_deserializer=storage__pb2.DatasetAvailableRequest.FromString, + response_serializer=storage__pb2.DeleteDatasetResponse.SerializeToString, + ), + "DeleteData": grpc.unary_unary_rpc_method_handler( + servicer.DeleteData, + request_deserializer=storage__pb2.DeleteDataRequest.FromString, + response_serializer=storage__pb2.DeleteDataResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'modyn.storage.Storage', rpc_method_handlers) + "modyn.storage.Storage", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. 
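For orientation, the reformatted stubs above are the Python side of the same `modyn.storage.Storage` service that this PR's new C++ backend serves. Below is a minimal, hypothetical C++ client sketch; it only uses RPCs and messages visible in this diff (for example, `GetCurrentTimestamp` now taking a dedicated `GetCurrentTimestampRequest` instead of `google.protobuf.Empty`), while the endpoint address, port, and generated header name are assumptions rather than anything this diff pins down.

#include <grpcpp/grpcpp.h>
#include <iostream>
#include "storage.grpc.pb.h"  // assumed name of the header emitted for storage.proto

int main() {
  // Address/port are assumptions; they are not defined in this hunk.
  const std::shared_ptr<grpc::Channel> channel =
      grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
  const std::unique_ptr<modyn::storage::Storage::Stub> stub =
      modyn::storage::Storage::NewStub(channel);

  // GetCurrentTimestamp now uses a dedicated request message instead of
  // google.protobuf.Empty, matching the regenerated Python stubs above.
  modyn::storage::GetCurrentTimestampRequest request;
  modyn::storage::GetCurrentTimestampResponse response;
  grpc::ClientContext context;

  const grpc::Status status = stub->GetCurrentTimestamp(&context, request, &response);
  if (status.ok()) {
    std::cout << "Current storage timestamp: " << response.timestamp() << '\n';
  } else {
    std::cerr << "RPC failed: " << status.error_message() << '\n';
  }
  return status.ok() ? 0 : 1;
}
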
class Storage(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Get(request, + def Get( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/Get', + "/modyn.storage.Storage/Get", storage__pb2.GetRequest.SerializeToString, storage__pb2.GetResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetNewDataSince(request, + def GetNewDataSince( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetNewDataSince', + "/modyn.storage.Storage/GetNewDataSince", storage__pb2.GetNewDataSinceRequest.SerializeToString, storage__pb2.GetNewDataSinceResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetDataInInterval(request, + def GetDataInInterval( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataInInterval', + "/modyn.storage.Storage/GetDataInInterval", storage__pb2.GetDataInIntervalRequest.SerializeToString, storage__pb2.GetDataInIntervalResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetDataPerWorker(request, + def GetDataPerWorker( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/modyn.storage.Storage/GetDataPerWorker', + "/modyn.storage.Storage/GetDataPerWorker", 
storage__pb2.GetDataPerWorkerRequest.SerializeToString, storage__pb2.GetDataPerWorkerResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetDatasetSize(request, + def GetDatasetSize( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetDatasetSize', + "/modyn.storage.Storage/GetDatasetSize", storage__pb2.GetDatasetSizeRequest.SerializeToString, storage__pb2.GetDatasetSizeResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def CheckAvailability(request, + def CheckAvailability( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/CheckAvailability', + "/modyn.storage.Storage/CheckAvailability", storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DatasetAvailableResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def RegisterNewDataset(request, + def RegisterNewDataset( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/RegisterNewDataset', + "/modyn.storage.Storage/RegisterNewDataset", storage__pb2.RegisterNewDatasetRequest.SerializeToString, storage__pb2.RegisterNewDatasetResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetCurrentTimestamp(request, + def GetCurrentTimestamp( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, 
target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/GetCurrentTimestamp', - google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + "/modyn.storage.Storage/GetCurrentTimestamp", + storage__pb2.GetCurrentTimestampRequest.SerializeToString, storage__pb2.GetCurrentTimestampResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def DeleteDataset(request, + def DeleteDataset( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/DeleteDataset', + "/modyn.storage.Storage/DeleteDataset", storage__pb2.DatasetAvailableRequest.SerializeToString, storage__pb2.DeleteDatasetResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def DeleteData(request, + def DeleteData( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modyn.storage.Storage/DeleteData', + "/modyn.storage.Storage/DeleteData", storage__pb2.DeleteDataRequest.SerializeToString, storage__pb2.DeleteDataResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/modyn/storage/internal/grpc/grpc_server.py b/modyn/storage/internal/grpc/grpc_server.py deleted file mode 100644 index 6e57e596f..000000000 --- a/modyn/storage/internal/grpc/grpc_server.py +++ /dev/null @@ -1,27 +0,0 @@ -"""GRPC server context manager.""" - - -import logging -from typing import Any - -from modyn.common.grpc import GenericGRPCServer -from modyn.storage.internal.grpc.generated.storage_pb2_grpc import add_StorageServicer_to_server -from modyn.storage.internal.grpc.storage_grpc_servicer import StorageGRPCServicer - -logger = logging.getLogger(__name__) - - -class StorageGRPCServer(GenericGRPCServer): - """GRPC server context manager.""" - - @staticmethod - def callback(modyn_config: dict, server: Any) -> None: - add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server) - - def __init__(self, modyn_config: dict) -> None: - """Initialize the GRPC server. 
- - Args: - modyn_config (dict): Configuration of the storage module. - """ - super().__init__(modyn_config, modyn_config["storage"]["port"], StorageGRPCServer.callback) diff --git a/modyn/storage/internal/grpc/storage_grpc_servicer.py b/modyn/storage/internal/grpc/storage_grpc_servicer.py deleted file mode 100644 index a28de26f0..000000000 --- a/modyn/storage/internal/grpc/storage_grpc_servicer.py +++ /dev/null @@ -1,408 +0,0 @@ -"""Storage GRPC servicer.""" - -import logging -import os -import threading -from typing import Iterable, Tuple - -import grpc -from modyn.common.benchmark.stopwatch import Stopwatch -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.database.storage_database_utils import get_file_wrapper, get_filesystem_wrapper - -# pylint: disable-next=no-name-in-module -from modyn.storage.internal.grpc.generated.storage_pb2 import ( - DatasetAvailableRequest, - DatasetAvailableResponse, - DeleteDataRequest, - DeleteDataResponse, - DeleteDatasetResponse, - GetCurrentTimestampResponse, - GetDataInIntervalRequest, - GetDataInIntervalResponse, - GetDataPerWorkerRequest, - GetDataPerWorkerResponse, - GetDatasetSizeRequest, - GetDatasetSizeResponse, - GetNewDataSinceRequest, - GetNewDataSinceResponse, - GetRequest, - GetResponse, - RegisterNewDatasetRequest, - RegisterNewDatasetResponse, -) -from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageServicer -from modyn.utils.utils import current_time_millis, get_partition_for_worker -from sqlalchemy import and_, asc, select -from sqlalchemy.orm import Session - -logger = logging.getLogger(__name__) - - -class StorageGRPCServicer(StorageServicer): - """GRPC servicer for the storage module.""" - - def __init__(self, config: dict): - """Initialize the storage GRPC servicer. - - Args: - config (dict): Configuration of the storage module. - """ - self.modyn_config = config - self._sample_batch_size = self.modyn_config["storage"]["sample_batch_size"] - super().__init__() - - # pylint: disable-next=unused-argument,invalid-name,too-many-locals - def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[GetResponse]: - """Return the data for the given keys. - - Args: - request (GetRequest): Request containing the dataset name and the keys. - context (grpc.ServicerContext): Context of the request. - - Returns: - Iterable[GetResponse]: Response containing the data for the given keys. - - Yields: - Iterator[Iterable[GetResponse]]: Response containing the data for the given keys. 
- """ - tid = threading.get_native_id() - pid = os.getpid() - logger.info(f"[{pid}][{tid}] Received request for {len(request.keys)} items.") - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - yield GetResponse() - return - - stopw = Stopwatch() - stopw.start("GetSamples") - samples: list[Sample] = ( - session.query(Sample) - .filter(and_(Sample.sample_id.in_(request.keys), Sample.dataset_id == dataset.dataset_id)) - .order_by(Sample.file_id) - .all() - ) - samples_time = stopw.stop() - logger.info(f"[{pid}][{tid}] Getting samples took {samples_time / 1000}s.") - - if len(samples) == 0: - logger.error("No samples found in the database.") - yield GetResponse() - return - - if len(samples) != len(request.keys): - logger.error("Not all keys were found in the database.") - not_found_keys = {s for s in request.keys if s not in [sample.sample_id for sample in samples]} - logger.error(f"Keys: {not_found_keys}") - - current_file_id = samples[0].file_id - current_file = ( - session.query(File) - .filter(File.file_id == current_file_id and File.dataset_id == dataset.dataset_id) - .first() - ) - samples_per_file: list[Tuple[int, int, int]] = [] - - # Iterate over all samples and group them by file, the samples are sorted by file_id (see query above) - for sample in samples: - if sample.file_id != current_file.file_id: - file_wrapper = get_file_wrapper( - dataset.file_wrapper_type, - current_file.path, - dataset.file_wrapper_config, - get_filesystem_wrapper(dataset.filesystem_wrapper_type, dataset.base_path), - ) - yield GetResponse( - samples=file_wrapper.get_samples_from_indices([index for index, _, _ in samples_per_file]), - keys=[sample_id for _, sample_id, _ in samples_per_file], - labels=[label for _, _, label in samples_per_file], - ) - samples_per_file = [(sample.index, sample.sample_id, sample.label)] - current_file_id = sample.file_id - current_file = ( - session.query(File) - .filter(File.file_id == current_file_id and File.dataset_id == dataset.dataset_id) - .first() - ) - else: - samples_per_file.append((sample.index, sample.sample_id, sample.label)) - file_wrapper = get_file_wrapper( - dataset.file_wrapper_type, - current_file.path, - dataset.file_wrapper_config, - get_filesystem_wrapper(dataset.filesystem_wrapper_type, dataset.base_path), - ) - yield GetResponse( - samples=file_wrapper.get_samples_from_indices([index for index, _, _ in samples_per_file]), - keys=[sample_id for _, sample_id, _ in samples_per_file], - labels=[label for _, _, label in samples_per_file], - ) - - # pylint: disable-next=unused-argument,invalid-name - def GetNewDataSince( - self, request: GetNewDataSinceRequest, context: grpc.ServicerContext - ) -> Iterable[GetNewDataSinceResponse]: - """Get all new data since the given timestamp. - - Returns: - GetNewDataSinceResponse: A response containing all external keys since the given timestamp. 
- """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - yield GetNewDataSinceResponse() - return - - timestamp = request.timestamp - - stmt = ( - select(Sample.sample_id, File.updated_at, Sample.label) - .join(File, Sample.file_id == File.file_id and Sample.dataset_id == File.dataset_id) - # Enables batching of results in chunks. - # See https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-yield-per - .execution_options(yield_per=self._sample_batch_size) - .filter(File.dataset_id == dataset.dataset_id) - .filter(File.updated_at >= timestamp) - .order_by(asc(File.updated_at), asc(Sample.sample_id)) - ) - - for batch in database.session.execute(stmt).partitions(): - if len(batch) > 0: - yield GetNewDataSinceResponse( - keys=[value[0] for value in batch], - timestamps=[value[1] for value in batch], - labels=[value[2] for value in batch], - ) - - def GetDataInInterval( - self, request: GetDataInIntervalRequest, context: grpc.ServicerContext - ) -> Iterable[GetDataInIntervalResponse]: - """Get all data in the given interval. - - Returns: - GetDataInIntervalResponse: A response containing all external keys in the given interval inclusive. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - yield GetDataInIntervalResponse() - return - - stmt = ( - select(Sample.sample_id, File.updated_at, Sample.label) - .join(File, Sample.file_id == File.file_id and Sample.dataset_id == File.dataset_id) - # Enables batching of results in chunks. - # See https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-yield-per - .execution_options(yield_per=self._sample_batch_size) - .filter(File.dataset_id == dataset.dataset_id) - .filter(File.updated_at >= request.start_timestamp) - .filter(File.updated_at <= request.end_timestamp) - .order_by(asc(File.updated_at), asc(Sample.sample_id)) - ) - - for batch in database.session.execute(stmt).partitions(): - if len(batch) > 0: - yield GetDataInIntervalResponse( - keys=[value[0] for value in batch], - timestamps=[value[1] for value in batch], - labels=[value[2] for value in batch], - ) - - # pylint: disable-next=unused-argument,invalid-name - def GetDataPerWorker( - self, request: GetDataPerWorkerRequest, context: grpc.ServicerContext - ) -> Iterable[GetDataPerWorkerResponse]: - """Get keys from the given dataset for a worker. - - Returns: - GetDataPerWorkerResponse: A response containing external keys from the dataset for the worker. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - yield GetNewDataSinceResponse() - return - - total_keys = session.query(Sample.sample_id).filter(Sample.dataset_id == dataset.dataset_id).count() - start_index, limit = get_partition_for_worker(request.worker_id, request.total_workers, total_keys) - - stmt = ( - select(Sample.sample_id) - # Enables batching of results in chunks. 
- # See https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-yield-per - .execution_options(yield_per=self._sample_batch_size) - .filter(Sample.dataset_id == dataset.dataset_id) - .order_by(Sample.sample_id) - .offset(start_index) - .limit(limit) - ) - - for batch in database.session.execute(stmt).partitions(): - if len(batch) > 0: - yield GetDataPerWorkerResponse(keys=[value[0] for value in batch]) - - # pylint: disable-next=unused-argument,invalid-name - def GetDatasetSize(self, request: GetDatasetSizeRequest, context: grpc.ServicerContext) -> GetDatasetSizeResponse: - """Get the total amount of keys for a given dataset. - - Returns: - GetDatasetSizeResponse: A response containing the amount of keys for a given dataset. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - return GetDatasetSizeResponse(success=False) - - total_keys = session.query(Sample.sample_id).filter(Sample.dataset_id == dataset.dataset_id).count() - return GetDatasetSizeResponse(success=True, num_keys=total_keys) - - # pylint: disable-next=unused-argument,invalid-name - def CheckAvailability( - self, request: DatasetAvailableRequest, context: grpc.ServicerContext - ) -> DatasetAvailableResponse: - """Check if a dataset is available in the database. - - Returns: - DatasetAvailableResponse: True if the dataset is available, False otherwise. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - return DatasetAvailableResponse(available=False) - - return DatasetAvailableResponse(available=True) - - # pylint: disable-next=unused-argument,invalid-name - def RegisterNewDataset( - self, request: RegisterNewDatasetRequest, context: grpc.ServicerContext - ) -> RegisterNewDatasetResponse: - """Register a new dataset in the database. - - Returns: - RegisterNewDatasetResponse: True if the dataset was successfully registered, False otherwise. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - success = database.add_dataset( - request.dataset_id, - request.base_path, - request.filesystem_wrapper_type, - request.file_wrapper_type, - request.description, - request.version, - request.file_wrapper_config, - request.ignore_last_timestamp, - request.file_watcher_interval, - ) - return RegisterNewDatasetResponse(success=success) - - # pylint: disable-next=unused-argument,invalid-name - def GetCurrentTimestamp(self, request: None, context: grpc.ServicerContext) -> GetCurrentTimestampResponse: - """Get the current timestamp. - - Returns: - GetCurrentTimestampResponse: The current timestamp. - """ - return GetCurrentTimestampResponse(timestamp=current_time_millis()) - - # pylint: disable-next=unused-argument,invalid-name - def DeleteDataset(self, request: DatasetAvailableRequest, context: grpc.ServicerContext) -> DeleteDatasetResponse: - """Delete a dataset from the database. - - Returns: - DeleteDatasetResponse: True if the dataset was successfully deleted, False otherwise. 
- """ - with StorageDatabaseConnection(self.modyn_config) as database: - success = database.delete_dataset(request.dataset_id) - return DeleteDatasetResponse(success=success) - - def DeleteData(self, request: DeleteDataRequest, context: grpc.ServicerContext) -> DeleteDataResponse: - """Delete data from the database. - - Returns: - DeleteDataResponse: True if the data was successfully deleted, False otherwise. - """ - with StorageDatabaseConnection(self.modyn_config) as database: - session = database.session - dataset: Dataset = session.query(Dataset).filter(Dataset.name == request.dataset_id).first() - if dataset is None: - logger.error(f"Dataset with name {request.dataset_id} does not exist.") - return DeleteDataResponse(success=False) - - file_ids: list[Sample] = ( - session.query(Sample.file_id) - .filter(Sample.sample_id.in_(request.keys)) - .order_by(Sample.file_id) - .group_by(Sample.file_id) - .all() - ) - - for file_id in file_ids: - file_id = file_id[0] - file: File = session.query(File).filter(File.file_id == file_id).first() - if file is None: - logger.error(f"Could not find file for dataset {request.dataset_id}") - return DeleteDataResponse(success=False) - file_wrapper = get_file_wrapper( - dataset.file_wrapper_type, - file.path, - dataset.file_wrapper_config, - get_filesystem_wrapper(dataset.filesystem_wrapper_type, dataset.base_path), - ) - samples_to_delete = ( - session.query(Sample) - .filter(Sample.file_id == file.file_id) - .filter(Sample.sample_id.in_(request.keys)) - .all() - ) - file_wrapper.delete_samples(samples_to_delete) - file.number_of_samples -= len(samples_to_delete) - session.commit() - - session.query(Sample).filter(Sample.sample_id.in_(request.keys)).delete() - session.commit() - - self.remove_empty_files(session, dataset) - - return DeleteDataResponse(success=True) - - def remove_empty_files(self, session: Session, dataset: Dataset) -> None: - """Delete files that have no samples left.""" - files_to_delete = ( - session.query(File).filter(File.dataset_id == dataset.dataset_id).filter(File.number_of_samples == 0).all() - ) - for file in files_to_delete: - file_system_wrapper = get_filesystem_wrapper(dataset.filesystem_wrapper_type, dataset.base_path) - try: - file_system_wrapper.delete(file.path) - except FileNotFoundError: - logger.debug( - f"File {file.path} not found. Might have been deleted \ - already in the previous step of this method." 
- ) - session.query(File).filter(File.file_id == file.file_id).delete() - session.commit() diff --git a/modyn/storage/modyn-storage b/modyn/storage/modyn-storage index 68520410e..aa74eb1cd 100755 --- a/modyn/storage/modyn-storage +++ b/modyn/storage/modyn-storage @@ -1,3 +1,5 @@ #!/bin/bash MODYNPATH="$(python -c 'import modyn; print(modyn.__path__[0])')" -python -u $MODYNPATH/storage/storage_entrypoint.py "$@" \ No newline at end of file + +# run +$MODYNPATH/build/modyn/storage/modyn-storage "$@" diff --git a/modyn/storage/src/CMakeLists.txt b/modyn/storage/src/CMakeLists.txt new file mode 100644 index 000000000..fd95cd648 --- /dev/null +++ b/modyn/storage/src/CMakeLists.txt @@ -0,0 +1,138 @@ +set(MODYN_STORAGE_SOURCES + storage_server.cpp + internal/database/storage_database_connection.cpp + internal/database/cursor_handler.cpp + internal/file_watcher/file_watcher_watchdog.cpp + internal/file_watcher/file_watcher.cpp + internal/file_wrapper/binary_file_wrapper.cpp + internal/file_wrapper/csv_file_wrapper.cpp + internal/file_wrapper/file_wrapper_utils.cpp + internal/file_wrapper/single_sample_file_wrapper.cpp + internal/filesystem_wrapper/filesystem_wrapper_utils.cpp + internal/filesystem_wrapper/local_filesystem_wrapper.cpp + internal/grpc/storage_grpc_server.cpp + internal/grpc/storage_service_impl.cpp +) + +# Explicitly set all header files so that IDEs will recognize them as part of the project +set(MODYN_STORAGE_HEADERS + ../include/storage_server.hpp + ../include/internal/database/storage_database_connection.hpp + ../include/internal/database/cursor_handler.hpp + ../include/internal/file_watcher/file_watcher_watchdog.hpp + ../include/internal/file_watcher/file_watcher.hpp + ../include/internal/file_wrapper/file_wrapper.hpp + ../include/internal/file_wrapper/binary_file_wrapper.hpp + ../include/internal/file_wrapper/single_sample_file_wrapper.hpp + ../include/internal/file_wrapper/csv_file_wrapper.hpp + ../include/internal/file_wrapper/file_wrapper_utils.hpp + ../include/internal/filesystem_wrapper/filesystem_wrapper.hpp + ../include/internal/filesystem_wrapper/local_filesystem_wrapper.hpp + ../include/internal/filesystem_wrapper/filesystem_wrapper_utils.hpp + ../include/internal/grpc/storage_grpc_server.hpp + ../include/internal/grpc/storage_service_impl.hpp + ) + +set(MODYN-STORAGE_PROTOS + ../../protos/storage.proto +) + +add_library(modyn-storage-proto ${MODYN-STORAGE_PROTOS}) + +# We output the proto generated headers into the generated directory +# However, CMAKE_CURRENT_BINARY_DIR includes "src", such that the directory is [...]/src/../generated +# This is fine here, but then clang-tidy starts to match the auto-generated files, which we do not want +# Hence, we have to take the realpath of this directory. +# We have to generate the directory first to make realpath work. 
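The headers that protobuf_generate emits into PROTO_BINARY_DIR (set up just below) are what the sources listed in MODYN_STORAGE_SOURCES, e.g. internal/grpc/storage_service_impl.cpp, compile against. As a hedged illustration only, a servicer overriding one RPC from the generated base class could look like the sketch below; the actual class and method layout of storage_service_impl is not part of this hunk, and the include name is assumed to resolve via the PROTO_BINARY_DIR include path.

#include <chrono>
#include <grpcpp/grpcpp.h>
#include "storage.grpc.pb.h"  // assumed to be found via the generated-proto include directory

// Hypothetical servicer; the real implementation lives in
// internal/grpc/storage_service_impl.cpp and may differ.
class StorageServiceSketch final : public modyn::storage::Storage::Service {
 public:
  grpc::Status GetCurrentTimestamp(grpc::ServerContext* /*context*/,
                                   const modyn::storage::GetCurrentTimestampRequest* /*request*/,
                                   modyn::storage::GetCurrentTimestampResponse* response) override {
    // Milliseconds since epoch, mirroring the behaviour of the removed Python
    // servicer's current_time_millis().
    const auto now_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                            std::chrono::system_clock::now().time_since_epoch())
                            .count();
    response->set_timestamp(now_ms);
    return grpc::Status::OK;
  }
};
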
+set(PROTO_BINARY_DIR_REL "${CMAKE_CURRENT_BINARY_DIR}/../../../protos") +file(MAKE_DIRECTORY ${PROTO_BINARY_DIR_REL}) +execute_process(COMMAND realpath ${PROTO_BINARY_DIR_REL} OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE PROTO_BINARY_DIR) + +protobuf_generate( + TARGET modyn-storage-proto + OUT_VAR PROTO_GENERATED_FILES + IMPORT_DIRS ../../protos + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") + +protobuf_generate( + TARGET modyn-storage-proto + OUT_VAR PROTO_GENERATED_FILES + LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=\$" + # PLUGIN_OPTIONS "generate_mock_code=true" + IMPORT_DIRS ../../protos + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") + +message(STATUS "Proto generated files in ${PROTO_BINARY_DIR}: ${PROTO_GENERATED_FILES}") + +target_include_directories(modyn-storage-proto PUBLIC "$") + +target_link_libraries(modyn-storage-proto PUBLIC grpc++ grpc++_reflection) + +if (MODYN_USES_LOCAL_GRPC) + message(STATUS "Since we are using local gRPC, we need to fix linking issues. If you encounter errors, consider building gRPC from source.") + set(protobuf_ABSL_USED_TARGETS + absl::absl_check + absl::absl_log + absl::algorithm + absl::base + absl::bind_front + absl::bits + absl::btree + absl::cleanup + absl::cord + absl::core_headers + absl::debugging + absl::die_if_null + absl::dynamic_annotations + absl::flags + absl::flat_hash_map + absl::flat_hash_set + absl::function_ref + absl::hash + absl::layout + absl::log_initialize + absl::log_severity + absl::memory + absl::node_hash_map + absl::node_hash_set + absl::optional + absl::span + absl::status + absl::statusor + absl::strings + absl::synchronization + absl::time + absl::type_traits + absl::utility + absl::variant + absl::random_random + ) + + target_link_libraries(modyn-storage-proto PUBLIC protobuf::libprotobuf grpc_unsecure gpr libaddress_sorting.a libupb.a libcares.a libz.a utf8_range ${protobuf_ABSL_USED_TARGETS}) +else() + target_link_libraries(modyn-storage-proto PUBLIC libprotobuf) +endif() + +target_compile_options(modyn-storage-proto INTERFACE -Wno-unused-parameter -Wno-c++98-compat-extra-semi -Wno-conditional-uninitialized -Wno-documentation) + +target_sources(modyn-storage-library PRIVATE ${MODYN_STORAGE_HEADERS} ${MODYN_STORAGE_SOURCES}) +target_include_directories(modyn-storage-library PUBLIC ../include ${CMAKE_CURRENT_BINARY_DIR}/../clang-tidy-build/_deps/soci-src/include ${CMAKE_CURRENT_BINARY_DIR}/../build/_deps/soci-src/include ${CMAKE_CURRENT_BINARY_DIR}/_deps/include ${CMAKE_CURRENT_BINARY_DIR}/../_deps/include ${FETCHCONTENT_BASE_DIR}/include ${soci_SOURCE_DIR}/build/include ${PostgreSQL_INCLUDE_DIRS}) +target_compile_options(modyn-storage-library PRIVATE ${MODYN_COMPILE_OPTIONS}) + +target_link_libraries(modyn-storage-library PUBLIC modyn yaml-cpp ${PostgreSQL_LIBRARIES} soci_postgresql_static soci_sqlite3_static soci_core_static grpc++ grpc++_reflection modyn-storage-proto rapidcsv) + +message(STATUS "Current dir: ${CMAKE_CURRENT_SOURCE_DIR}") +message(STATUS "Current binary dir: ${CMAKE_CURRENT_BINARY_DIR}") + +target_compile_definitions(modyn-storage-library PRIVATE MODYN_BUILD_TYPE=\"${CMAKE_BUILD_TYPE}\") +target_compile_definitions(modyn-storage-library PRIVATE "MODYN_CMAKE_COMPILER=\"${MODYN_COMPILER_ENV} ${CMAKE_CXX_COMPILER}\"") +target_compile_definitions(modyn-storage-library PUBLIC ${MODYN_COMPILE_DEFINITIONS}) + +# This adds a `INCLUDE_DIRECTORIES` definition containing all include directories, separate by comma. 
+# The definition is set to PRIVATE, so it will not be exposed if the target is itself a dependency. +set(INCLUDE_EXPR "$") +set(INCLUDE_FILTER "$") +set(INCLUDE_JOINED "$") +target_compile_definitions(modyn-storage-library PRIVATE "INCLUDE_DIRECTORIES=\"${INCLUDE_JOINED}\"") \ No newline at end of file diff --git a/modyn/storage/src/internal/database/cursor_handler.cpp b/modyn/storage/src/internal/database/cursor_handler.cpp new file mode 100644 index 000000000..7ef97e3bb --- /dev/null +++ b/modyn/storage/src/internal/database/cursor_handler.cpp @@ -0,0 +1,115 @@ +#include "internal/database/cursor_handler.hpp" + +#include +#include +#include + +#include + +using namespace modyn::storage; + +std::vector CursorHandler::yield_per(const uint64_t number_of_rows_to_fetch) { + std::vector records; + check_cursor_initialized(); + + switch (driver_) { + case DatabaseDriver::POSTGRESQL: { + const std::string fetch_query = fmt::format("FETCH {} FROM {}", number_of_rows_to_fetch, cursor_name_); + ASSERT(number_of_rows_to_fetch <= std::numeric_limits::max(), + "Postgres can only accept up to MAX_INT rows per iteration"); + + PGresult* result = PQexec(postgresql_conn_, fetch_query.c_str()); + + if (PQresultStatus(result) != PGRES_TUPLES_OK) { + SPDLOG_ERROR("Cursor fetch failed: {}", PQerrorMessage(postgresql_conn_)); + PQclear(result); + return records; + } + + const auto rows = static_cast(PQntuples(result)); + records.resize(rows); + + for (uint64_t i = 0; i < rows; ++i) { + SampleRecord record{}; + const auto row_idx = static_cast(i); + record.id = std::stoll(PQgetvalue(result, row_idx, 0)); + if (number_of_columns_ > 1) { + record.column_1 = std::stoll(PQgetvalue(result, row_idx, 1)); + } + if (number_of_columns_ == 3) { + record.column_2 = std::stoll(PQgetvalue(result, row_idx, 2)); + } + + records[i] = record; + } + + PQclear(result); + return records; + break; + } + case DatabaseDriver::SQLITE3: { + uint64_t retrieved_rows = 0; + records.reserve(number_of_rows_to_fetch); + for (auto& row : *rs_) { + SampleRecord record{}; + record.id = StorageDatabaseConnection::get_from_row(row, 0); + if (number_of_columns_ > 1) { + record.column_1 = StorageDatabaseConnection::get_from_row(row, 1); + } + if (number_of_columns_ == 3) { + record.column_2 = StorageDatabaseConnection::get_from_row(row, 2); + } + records.push_back(record); + ++retrieved_rows; + if (retrieved_rows >= number_of_rows_to_fetch) { + break; + } + } + return records; + break; + } + default: + FAIL("Unsupported database driver"); + } +} + +void CursorHandler::check_cursor_initialized() { + if (rs_ == nullptr && postgresql_conn_ == nullptr) { + SPDLOG_ERROR("Cursor not initialized"); + throw std::runtime_error("Cursor not initialized"); + } +} + +void CursorHandler::close_cursor() { + if (!open_) { + return; + } + + switch (driver_) { + case DatabaseDriver::POSTGRESQL: { + auto* postgresql_session_backend = static_cast(session_.get_backend()); + if (postgresql_session_backend == nullptr) { + SPDLOG_ERROR("Cannot close cursor due to session being nullptr!"); + return; + } + + PGconn* conn = postgresql_session_backend->conn_; + + const std::string close_query = "CLOSE " + cursor_name_; + PGresult* result = PQexec(conn, close_query.c_str()); + + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + SPDLOG_ERROR(fmt::format("Cursor close failed: {}", PQerrorMessage(conn))); + } + + PQclear(result); + break; + } + case DatabaseDriver::SQLITE3: + break; + default: + FAIL("Unsupported database driver"); + } + + open_ = false; +} \ No newline at end 
of file diff --git a/modyn/storage/src/internal/database/sql/PostgreSQLDataset.sql b/modyn/storage/src/internal/database/sql/PostgreSQLDataset.sql new file mode 100644 index 000000000..9b9f10680 --- /dev/null +++ b/modyn/storage/src/internal/database/sql/PostgreSQLDataset.sql @@ -0,0 +1,13 @@ +R"(CREATE TABLE IF NOT EXISTS datasets ( + dataset_id SERIAL PRIMARY KEY, + name VARCHAR(80) NOT NULL, + description VARCHAR(120), + version VARCHAR(80), + filesystem_wrapper_type INTEGER, + file_wrapper_type INTEGER, + base_path VARCHAR(120) NOT NULL, + file_wrapper_config VARCHAR(240), + last_timestamp BIGINT NOT NULL, + ignore_last_timestamp BOOLEAN NOT NULL DEFAULT FALSE, + file_watcher_interval BIGINT NOT NULL DEFAULT 5 +))" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/sql/PostgreSQLFile.sql b/modyn/storage/src/internal/database/sql/PostgreSQLFile.sql new file mode 100644 index 000000000..65b2830b8 --- /dev/null +++ b/modyn/storage/src/internal/database/sql/PostgreSQLFile.sql @@ -0,0 +1,11 @@ +R"(CREATE TABLE IF NOT EXISTS files ( + file_id BIGSERIAL PRIMARY KEY, + dataset_id INTEGER NOT NULL, + path VARCHAR(120) NOT NULL, + updated_at BIGINT, + number_of_samples INTEGER +); + +CREATE INDEX IF NOT EXISTS files_dataset_id ON files (dataset_id); + +CREATE INDEX IF NOT EXISTS files_updated_at ON files (updated_at);)" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/sql/PostgreSQLSample.sql b/modyn/storage/src/internal/database/sql/PostgreSQLSample.sql new file mode 100644 index 000000000..8329ec943 --- /dev/null +++ b/modyn/storage/src/internal/database/sql/PostgreSQLSample.sql @@ -0,0 +1,9 @@ +R"(CREATE TABLE IF NOT EXISTS samples ( + sample_id BIGSERIAL NOT NULL, + dataset_id INTEGER NOT NULL, + file_id INTEGER, + sample_index BIGINT, + label BIGINT, + PRIMARY KEY (dataset_id, sample_id) + +) PARTITION BY LIST (dataset_id))" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/sql/SQLiteDataset.sql b/modyn/storage/src/internal/database/sql/SQLiteDataset.sql new file mode 100644 index 000000000..91b55e351 --- /dev/null +++ b/modyn/storage/src/internal/database/sql/SQLiteDataset.sql @@ -0,0 +1,13 @@ +R"(CREATE TABLE IF NOT EXISTS datasets ( + dataset_id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(80) NOT NULL, + description VARCHAR(120), + version VARCHAR(80), + filesystem_wrapper_type INTEGER, + file_wrapper_type INTEGER, + base_path VARCHAR(120) NOT NULL, + file_wrapper_config VARCHAR(240), + last_timestamp BIGINT NOT NULL, + ignore_last_timestamp BOOLEAN NOT NULL DEFAULT FALSE, + file_watcher_interval BIGINT NOT NULL DEFAULT 5 +))" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/sql/SQLiteFile.sql b/modyn/storage/src/internal/database/sql/SQLiteFile.sql new file mode 100644 index 000000000..2727e8586 --- /dev/null +++ b/modyn/storage/src/internal/database/sql/SQLiteFile.sql @@ -0,0 +1,7 @@ +R"(CREATE TABLE IF NOT EXISTS files ( + file_id INTEGER PRIMARY KEY AUTOINCREMENT, + dataset_id INTEGER NOT NULL, + path VARCHAR(120) NOT NULL, + updated_at BIGINT, + number_of_samples INTEGER +))" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/sql/SQLiteSample.sql b/modyn/storage/src/internal/database/sql/SQLiteSample.sql new file mode 100644 index 000000000..9fea0218c --- /dev/null +++ b/modyn/storage/src/internal/database/sql/SQLiteSample.sql @@ -0,0 +1,7 @@ +R"(CREATE TABLE IF NOT EXISTS samples ( + sample_id INTEGER PRIMARY KEY AUTOINCREMENT, + dataset_id 
INTEGER NOT NULL, + file_id INTEGER, + sample_index BIGINT, + label BIGINT +))" \ No newline at end of file diff --git a/modyn/storage/src/internal/database/storage_database_connection.cpp b/modyn/storage/src/internal/database/storage_database_connection.cpp new file mode 100644 index 000000000..3adc7b0b6 --- /dev/null +++ b/modyn/storage/src/internal/database/storage_database_connection.cpp @@ -0,0 +1,245 @@ +#include "internal/database/storage_database_connection.hpp" + +#include +#include + +#include +#include + +#include "modyn/utils/utils.hpp" +#include "soci/postgresql/soci-postgresql.h" +#include "soci/sqlite3/soci-sqlite3.h" + +using namespace modyn::storage; + +soci::session StorageDatabaseConnection::get_session() const { + const std::string connection_string = + fmt::format("dbname={} user={} password={} host={} port={}", database_, username_, password_, host_, port_); + soci::connection_parameters parameters; + + switch (drivername_) { + case DatabaseDriver::POSTGRESQL: + parameters = soci::connection_parameters(soci::postgresql, connection_string); + break; + case DatabaseDriver::SQLITE3: + parameters = soci::connection_parameters(soci::sqlite3, connection_string); + break; + default: + FAIL("Unsupported database driver"); + } + return soci::session(parameters); +} + +void StorageDatabaseConnection::create_tables() const { + soci::session session = get_session(); + + std::string dataset_table_sql; + std::string file_table_sql; + std::string sample_table_sql; + switch (drivername_) { + case DatabaseDriver::POSTGRESQL: + dataset_table_sql = +#include "sql/PostgreSQLDataset.sql" + ; + file_table_sql = +#include "sql/PostgreSQLFile.sql" + ; + sample_table_sql = +#include "sql/PostgreSQLSample.sql" + ; + break; + case DatabaseDriver::SQLITE3: + dataset_table_sql = +#include "sql/SQLiteDataset.sql" + ; + file_table_sql = +#include "sql/SQLiteFile.sql" + ; + sample_table_sql = +#include "sql/SQLiteSample.sql" + ; + break; + default: + FAIL("Unsupported database driver"); + } + session << dataset_table_sql; + + session << file_table_sql; + + session << sample_table_sql; + + if (drivername_ == DatabaseDriver::POSTGRESQL && sample_table_unlogged_) { + session << "ALTER TABLE samples SET UNLOGGED"; + } + + session.close(); +} + +bool StorageDatabaseConnection::add_dataset(const std::string& name, const std::string& base_path, + const FilesystemWrapperType& filesystem_wrapper_type, + const FileWrapperType& file_wrapper_type, const std::string& description, + const std::string& version, const std::string& file_wrapper_config, + const bool ignore_last_timestamp, + const int64_t file_watcher_interval) const { + soci::session session = get_session(); + + auto filesystem_wrapper_type_int = static_cast(filesystem_wrapper_type); + auto file_wrapper_type_int = static_cast(file_wrapper_type); + std::string boolean_string = ignore_last_timestamp ? 
"true" : "false"; + + if (get_dataset_id(name) != -1) { + SPDLOG_ERROR("Dataset {} already exists", name); + return false; + } + switch (drivername_) { + case DatabaseDriver::POSTGRESQL: + try { + session << "INSERT INTO datasets (name, base_path, filesystem_wrapper_type, " + "file_wrapper_type, description, version, file_wrapper_config, " + "ignore_last_timestamp, file_watcher_interval, last_timestamp) " + "VALUES (:name, " + ":base_path, :filesystem_wrapper_type, :file_wrapper_type, " + ":description, :version, :file_wrapper_config, " + ":ignore_last_timestamp, :file_watcher_interval, 0)", + soci::use(name), soci::use(base_path), soci::use(filesystem_wrapper_type_int), + soci::use(file_wrapper_type_int), soci::use(description), soci::use(version), + soci::use(file_wrapper_config), soci::use(boolean_string), soci::use(file_watcher_interval); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error adding dataset: {}", e.what()); + return false; + } + break; + case DatabaseDriver::SQLITE3: + session << "INSERT INTO datasets (name, base_path, filesystem_wrapper_type, " + "file_wrapper_type, description, version, file_wrapper_config, " + "ignore_last_timestamp, file_watcher_interval, last_timestamp) " + "VALUES (:name, " + ":base_path, :filesystem_wrapper_type, :file_wrapper_type, " + ":description, :version, :file_wrapper_config, " + ":ignore_last_timestamp, :file_watcher_interval, 0)", + soci::use(name), soci::use(base_path), soci::use(filesystem_wrapper_type_int), + soci::use(file_wrapper_type_int), soci::use(description), soci::use(version), soci::use(file_wrapper_config), + soci::use(boolean_string), soci::use(file_watcher_interval); + break; + default: + SPDLOG_ERROR("Error adding dataset: Unsupported database driver."); + return false; + } + + // Create partition table for samples + if (!add_sample_dataset_partition(name)) { + FAIL("Partition creation failed."); + } + session.close(); + + return true; +} + +int64_t StorageDatabaseConnection::get_dataset_id(const std::string& name) const { + soci::session session = get_session(); + + int64_t dataset_id = -1; + session << "SELECT dataset_id FROM datasets WHERE name = :name", soci::into(dataset_id), soci::use(name); + session.close(); + + return dataset_id; +} + +DatabaseDriver StorageDatabaseConnection::get_drivername(const YAML::Node& config) { + ASSERT(config["storage"]["database"], "No database configuration found"); + + const auto drivername = config["storage"]["database"]["drivername"].as(); + if (drivername == "postgresql") { + return DatabaseDriver::POSTGRESQL; + } + if (drivername == "sqlite3") { + return DatabaseDriver::SQLITE3; + } + + FAIL("Unsupported database driver: " + drivername); +} + +bool StorageDatabaseConnection::delete_dataset(const std::string& name, const int64_t dataset_id) const { + soci::session session = get_session(); + + // Delete all samples for this dataset + try { + session << "DELETE FROM samples WHERE dataset_id = :dataset_id", soci::use(dataset_id); + } catch (const soci::soci_error& e) { + SPDLOG_ERROR("Error deleting samples for dataset {}: {}", name, e.what()); + return false; + } + + // Delete all files for this dataset + try { + session << "DELETE FROM files WHERE dataset_id = :dataset_id", soci::use(dataset_id); + } catch (const soci::soci_error& e) { + SPDLOG_ERROR("Error deleting files for dataset {}: {}", name, e.what()); + return false; + } + + // Delete the dataset + try { + session << "DELETE FROM datasets WHERE name = :name", soci::use(name); + } catch (const soci::soci_error& e) { + 
SPDLOG_ERROR("Error deleting dataset {}: {}", name, e.what()); + return false; + } + + session.close(); + + return true; +} + +bool StorageDatabaseConnection::add_sample_dataset_partition(const std::string& dataset_name) const { + soci::session session = get_session(); + int64_t dataset_id = get_dataset_id(dataset_name); + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} not found", dataset_name); + return false; + } + switch (drivername_) { + case DatabaseDriver::POSTGRESQL: { + std::string dataset_partition_table_name = "samples__did" + std::to_string(dataset_id); + try { + session << fmt::format( + "CREATE TABLE IF NOT EXISTS {} " + "PARTITION OF samples " + "FOR VALUES IN ({}) " + "PARTITION BY HASH (sample_id)", + dataset_partition_table_name, dataset_id); + } catch (const soci::soci_error& e) { + SPDLOG_ERROR("Error creating partition table for dataset {}: {}", dataset_name, e.what()); + return false; + } + + try { + for (int64_t i = 0; i < hash_partition_modulus_; i++) { + std::string hash_partition_name = dataset_partition_table_name + "_part" + std::to_string(i); + session << fmt::format( + "CREATE TABLE IF NOT EXISTS {} " + "PARTITION OF {} " + "FOR VALUES WITH (modulus {}, REMAINDER {})", + hash_partition_name, dataset_partition_table_name, hash_partition_modulus_, i); + } + } catch (const soci::soci_error& e) { + SPDLOG_ERROR("Error creating hash partitions for dataset {}: {}", dataset_name, e.what()); + return false; + } + break; + } + case DatabaseDriver::SQLITE3: { + SPDLOG_INFO( + "Skipping partition creation for dataset {}, not supported for " + "driver.", + dataset_name); + break; + } + default: + FAIL("Unsupported database driver."); + } + + session.close(); + + return true; +} diff --git a/modyn/storage/src/internal/file_watcher/file_watcher.cpp b/modyn/storage/src/internal/file_watcher/file_watcher.cpp new file mode 100644 index 000000000..ddffa02b1 --- /dev/null +++ b/modyn/storage/src/internal/file_watcher/file_watcher.cpp @@ -0,0 +1,442 @@ +#include "internal/file_watcher/file_watcher.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "internal/file_wrapper/file_wrapper_utils.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" + +using namespace modyn::storage; + +/* + * Checks if the file is to be inserted into the database. Assumes file extension has already been validated. + * + * Files to be inserted into the database are defined as files that adhere to the following rules: + * - The file is not already in the database. + * - If we are not ignoring the last modified timestamp, the file has been modified since the last check. 
+ */ +bool FileWatcher::check_file_for_insertion(const std::string& file_path, bool ignore_last_timestamp, int64_t timestamp, + int64_t dataset_id, + const std::shared_ptr& filesystem_wrapper, + soci::session& session) { + if (file_path.empty()) { + return false; + } + + int64_t file_id = -1; + session << "SELECT file_id FROM files WHERE path = :file_path AND dataset_id = :dataset_id", soci::into(file_id), + soci::use(file_path), soci::use(dataset_id); + + if (file_id == -1) { + if (ignore_last_timestamp) { + return true; + } + try { + const int64_t& modified_time = filesystem_wrapper->get_modified_time(file_path); + /* if (modified_time <= timestamp) { + SPDLOG_INFO("File {} has modified time {}, timestamp is {}, discarding", file_path, modified_time, timestamp); + } */ + return modified_time >= timestamp || timestamp == 0; + } catch (const std::exception& e) { + SPDLOG_ERROR(fmt::format( + "Error while checking modified time of file {}. It could be that a deletion request is currently running: {}", + file_path, e.what())); + return false; + } + } /* else { + SPDLOG_INFO("File {} is already known under id {}, discarding", file_path, file_id); + } */ + return false; +} + +/* + * Searches for new files in the given directory and updates the files in the database. + * + * Iterates over all files in the directory and depending on whether we are multi or single threaded, either handles the + * file paths directly or spawns new threads to handle the file paths. + * + * Each thread spawned will handle an equal share of the files in the directory. + */ +void FileWatcher::search_for_new_files_in_directory(const std::string& directory_path, int64_t timestamp) { + std::vector file_paths = + filesystem_wrapper->list(directory_path, /*recursive=*/true, data_file_extension_); + SPDLOG_INFO("Found {} files in total", file_paths.size()); + + if (file_paths.empty()) { + return; + } + + if (disable_multithreading_) { + std::atomic exception_thrown = false; + FileWatcher::handle_file_paths(file_paths.begin(), file_paths.end(), file_wrapper_type_, timestamp, + filesystem_wrapper_type_, dataset_id_, &file_wrapper_config_node_, &config_, + sample_dbinsertion_batchsize_, force_fallback_, &exception_thrown); + if (exception_thrown.load()) { + *stop_file_watcher = true; + } + } else { + const auto chunk_size = static_cast(file_paths.size()) / static_cast(insertion_threads_); + SPDLOG_INFO("Inserting {} files per thread (total = {} threads)", chunk_size, insertion_threads_); + + for (int16_t i = 0; i < insertion_threads_; ++i) { + SPDLOG_INFO("Spawning thread {}/{} for insertion.", i + 1, insertion_threads_); + // NOLINTNEXTLINE(modernize-use-auto): Let's be explicit about the iterator type here + const std::vector::iterator begin = file_paths.begin() + static_cast(i) * chunk_size; + // NOLINTNEXTLINE(modernize-use-auto): Let's be explicit about the iterator type here + const std::vector::iterator end = + (i < insertion_threads_ - 1) ? 
(begin + chunk_size) : file_paths.end(); + + std::atomic* exception_thrown = &insertion_thread_exceptions_.at(i); + exception_thrown->store(false); + + insertion_thread_pool_.emplace_back(FileWatcher::handle_file_paths, begin, end, file_wrapper_type_, timestamp, + filesystem_wrapper_type_, dataset_id_, &file_wrapper_config_node_, &config_, + sample_dbinsertion_batchsize_, force_fallback_, exception_thrown); + } + + uint16_t index = 0; + for (auto& thread : insertion_thread_pool_) { + // handle if any thread throws an exception + if (insertion_thread_exceptions_[index].load()) { + *stop_file_watcher = true; + break; + } + index++; + if (thread.joinable()) { + thread.join(); + } + } + insertion_thread_pool_.clear(); + } +} + +/* + * Updating the files in the database for the given directory with the last inserted timestamp. + */ +void FileWatcher::seek_dataset(soci::session& session) { + int64_t last_timestamp = -1; + + session << "SELECT last_timestamp FROM datasets " + "WHERE dataset_id = :dataset_id", + soci::into(last_timestamp), soci::use(dataset_id_); + + SPDLOG_INFO("Seeking dataset {} with last timestamp = {}", dataset_id_, last_timestamp); + + search_for_new_files_in_directory(dataset_path_, last_timestamp); +} + +/* + * Seeking the dataset and updating the last inserted timestamp. + */ +void FileWatcher::seek(soci::session& session) { + seek_dataset(session); + + int64_t last_timestamp = -1; + session << "SELECT updated_at FROM files WHERE dataset_id = :dataset_id ORDER " + "BY updated_at DESC LIMIT 1", + soci::into(last_timestamp), soci::use(dataset_id_); + + if (last_timestamp > 0) { + session << "UPDATE datasets SET last_timestamp = :last_timestamp WHERE dataset_id = " + ":dataset_id", + soci::use(last_timestamp), soci::use(dataset_id_); + } +} + +void FileWatcher::run() { + soci::session session = storage_database_connection_.get_session(); + + int64_t file_watcher_interval = -1; + session << "SELECT file_watcher_interval FROM datasets WHERE dataset_id = :dataset_id", + soci::into(file_watcher_interval), soci::use(dataset_id_); + + if (file_watcher_interval == -1) { + SPDLOG_ERROR("Failed to get file watcher interval"); + *stop_file_watcher = true; + return; + } + + while (true) { + try { + seek(session); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error while seeking dataset: {}", e.what()); + stop_file_watcher->store(true); + } + if (stop_file_watcher->load()) { + SPDLOG_INFO("FileWatcher for dataset {} is exiting.", dataset_id_); + break; + } + std::this_thread::sleep_for(std::chrono::seconds(file_watcher_interval)); + if (stop_file_watcher->load()) { + SPDLOG_INFO("FileWatcher for dataset {} is exiting.", dataset_id_); + break; + } + } + + session.close(); +} + +void FileWatcher::handle_file_paths(const std::vector::iterator file_paths_begin, + const std::vector::iterator file_paths_end, + const FileWrapperType file_wrapper_type, int64_t timestamp, + const FilesystemWrapperType filesystem_wrapper_type, const int64_t dataset_id, + const YAML::Node* file_wrapper_config, const YAML::Node* config, + const int64_t sample_dbinsertion_batchsize, const bool force_fallback, + std::atomic* exception_thrown) { + try { + SPDLOG_INFO("Hi, this is handle_file_paths. 
Checking {} items", file_paths_end - file_paths_begin); + if (file_paths_begin >= file_paths_end) { + return; + } + + const StorageDatabaseConnection storage_database_connection(*config); + soci::session session = storage_database_connection.get_session(); + + auto filesystem_wrapper = get_filesystem_wrapper(filesystem_wrapper_type); + + int ignore_last_timestamp = 0; + session << "SELECT ignore_last_timestamp FROM datasets WHERE dataset_id = :dataset_id", + soci::into(ignore_last_timestamp), soci::use(dataset_id); + + // 1. Batch files into chunks + + const int64_t num_paths = file_paths_end - file_paths_begin; + int64_t num_chunks = num_paths / sample_dbinsertion_batchsize; + + if (num_paths % sample_dbinsertion_batchsize != 0) { + ++num_chunks; + } + + std::vector unknown_files; + + for (int64_t i = 0; i < num_chunks; ++i) { + SPDLOG_INFO("Handling chunk {}/{}", i + 1, num_chunks); + auto start_it = file_paths_begin + i * sample_dbinsertion_batchsize; + auto end_it = i < num_chunks - 1 ? start_it + sample_dbinsertion_batchsize : file_paths_end; + std::vector chunk_paths(start_it, end_it); + const std::string known_files_query = fmt::format( + "SELECT path FROM files WHERE path IN ('{}') AND dataset_id = :dataset_id", fmt::join(chunk_paths, "','")); + std::vector known_paths(sample_dbinsertion_batchsize); + // SPDLOG_INFO("Chunk: {}/{} prepared query", i + 1, num_chunks); + session << known_files_query, soci::into(known_paths), soci::use(dataset_id); + // SPDLOG_INFO("Chunk: {}/{} executed query", i + 1, num_chunks); + std::unordered_set known_paths_set(known_paths.begin(), known_paths.end()); + // SPDLOG_INFO("Chunk: {}/{} prepared hashtable", i + 1, num_chunks); + + std::copy_if(chunk_paths.begin(), chunk_paths.end(), std::back_inserter(unknown_files), + [&known_paths_set](const std::string& file_path) { return !known_paths_set.contains(file_path); }); + } + SPDLOG_INFO("Found {} unknown files!", unknown_files.size()); + std::vector files_for_insertion; + + if (ignore_last_timestamp == 0) { + files_for_insertion.reserve(unknown_files.size()); + auto logger = spdlog::default_logger(); // we cannot use SPDLOG_ERROR inside the lambda below + + std::copy_if(unknown_files.begin(), unknown_files.end(), std::back_inserter(files_for_insertion), + [&filesystem_wrapper, ×tamp, &logger](const std::string& file_path) { + try { + const int64_t& modified_time = filesystem_wrapper->get_modified_time(file_path); + return modified_time >= timestamp || timestamp == 0; + } catch (const std::exception& mod_e) { + logger->error( + fmt::format("Error while checking modified time of file {}. 
It could be that a deletion " + "request is currently running: {}", + file_path, mod_e.what())); + return false; + } + }); + } else { + files_for_insertion = unknown_files; + } + + unknown_files.clear(); + unknown_files.shrink_to_fit(); + + if (!files_for_insertion.empty()) { + SPDLOG_INFO("Found {} files for insertion!", files_for_insertion.size()); + DatabaseDriver database_driver = storage_database_connection.get_drivername(); + handle_files_for_insertion(files_for_insertion, file_wrapper_type, dataset_id, *file_wrapper_config, + sample_dbinsertion_batchsize, force_fallback, session, database_driver, + filesystem_wrapper); + } + session.close(); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error while handling file paths: {}", e.what()); + exception_thrown->store(true); + } +} + +void FileWatcher::handle_files_for_insertion(std::vector& files_for_insertion, + const FileWrapperType& file_wrapper_type, const int64_t dataset_id, + const YAML::Node& file_wrapper_config, + const int64_t sample_dbinsertion_batchsize, const bool force_fallback, + soci::session& session, DatabaseDriver& database_driver, + const std::shared_ptr& filesystem_wrapper) { + const std::string file_path = files_for_insertion.front(); + std::vector file_samples; + auto file_wrapper = get_file_wrapper(file_path, file_wrapper_type, file_wrapper_config, filesystem_wrapper); + + int64_t current_file_samples_to_be_inserted = 0; + for (const auto& file_path : files_for_insertion) { + file_wrapper->set_file_path(file_path); + const int64_t file_id = + insert_file(file_path, dataset_id, filesystem_wrapper, file_wrapper, session, database_driver); + + if (file_id == -1) { + SPDLOG_ERROR("Failed to insert file into database"); + continue; + } + + const std::vector labels = file_wrapper->get_all_labels(); + + int32_t index = 0; + for (const auto& label : labels) { + if (current_file_samples_to_be_inserted == sample_dbinsertion_batchsize) { + insert_file_samples(file_samples, dataset_id, force_fallback, session, database_driver); + file_samples.clear(); + current_file_samples_to_be_inserted = 0; + } + file_samples.push_back({file_id, index, label}); + index++; + current_file_samples_to_be_inserted++; + } + } + + if (!file_samples.empty()) { + insert_file_samples(file_samples, dataset_id, force_fallback, session, database_driver); + } +} + +int64_t FileWatcher::insert_file(const std::string& file_path, const int64_t dataset_id, + const std::shared_ptr& filesystem_wrapper, + const std::unique_ptr& file_wrapper, soci::session& session, + DatabaseDriver& database_driver) { + uint64_t number_of_samples = 0; + number_of_samples = file_wrapper->get_number_of_samples(); + const int64_t modified_time = filesystem_wrapper->get_modified_time(file_path); + int64_t file_id = -1; + + // soci::session::get_last_insert_id() is not supported by postgresql, so we need to use a different query. 
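+ // Concretely: the SQLite3 path below performs a plain INSERT and then reads the id via
+ // session.get_last_insert_id("files", file_id), whereas the PostgreSQL path appends "RETURNING file_id"
+ // to the INSERT and reads the id via soci::into.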
+ if (database_driver == DatabaseDriver::SQLITE3) { + file_id = insert_file(file_path, dataset_id, session, number_of_samples, modified_time); + } else if (database_driver == DatabaseDriver::POSTGRESQL) { + file_id = insert_file_using_returning_statement(file_path, dataset_id, session, number_of_samples, modified_time); + } + return file_id; +} + +int64_t FileWatcher::insert_file(const std::string& file_path, const int64_t dataset_id, soci::session& session, + uint64_t number_of_samples, int64_t modified_time) { + session << "INSERT INTO files (dataset_id, path, number_of_samples, " + "updated_at) VALUES (:dataset_id, :path, " + ":updated_at, :number_of_samples)", + soci::use(dataset_id), soci::use(file_path), soci::use(modified_time), soci::use(number_of_samples); + + long long file_id = -1; // NOLINT google-runtime-int (Linux otherwise complains about the following call) + if (!session.get_last_insert_id("files", file_id)) { + SPDLOG_ERROR("Failed to insert file into database"); + return -1; + } + return file_id; +} + +int64_t FileWatcher::insert_file_using_returning_statement(const std::string& file_path, const int64_t dataset_id, + soci::session& session, uint64_t number_of_samples, + int64_t modified_time) { + // SPDLOG_INFO( + // fmt::format("Inserting file {} with {} samples for dataset {}", file_path, number_of_samples, dataset_id)); + int64_t file_id = -1; + session << "INSERT INTO files (dataset_id, path, number_of_samples, " + "updated_at) VALUES (:dataset_id, :path, " + ":number_of_samples, :updated_at) RETURNING file_id", + soci::use(dataset_id), soci::use(file_path), soci::use(number_of_samples), soci::use(modified_time), + soci::into(file_id); + // SPDLOG_INFO(fmt::format("Inserted file {} into file ID {}", file_path, file_id)); + + if (file_id == -1) { + SPDLOG_ERROR("Failed to insert file into database"); + return -1; + } + return file_id; +} + +void FileWatcher::insert_file_samples(const std::vector& file_samples, const int64_t dataset_id, + const bool force_fallback, soci::session& session, + DatabaseDriver& database_driver) { + if (force_fallback) { + return fallback_insertion(file_samples, dataset_id, session); + } + + switch (database_driver) { + case DatabaseDriver::POSTGRESQL: + return postgres_copy_insertion(file_samples, dataset_id, session); + case DatabaseDriver::SQLITE3: + return fallback_insertion(file_samples, dataset_id, session); + default: + FAIL("Unsupported database driver"); + } +} + +/* + * Inserts the file frame into the database using the optimized postgresql copy command. + * + * The data is expected in a vector of FileFrame which is defined as file_id, sample_index, label. + */ +void FileWatcher::postgres_copy_insertion(const std::vector& file_samples, const int64_t dataset_id, + soci::session& session) { + SPDLOG_INFO(fmt::format("Doing copy insertion for {} samples", file_samples.size())); + auto* postgresql_session_backend = static_cast(session.get_backend()); + PGconn* conn = postgresql_session_backend->conn_; + + const std::string copy_query = + "COPY samples(dataset_id,file_id,sample_index,label) FROM STDIN WITH (DELIMITER ',', FORMAT CSV)"; + + PQexec(conn, copy_query.c_str()); + // put the data into the buffer + std::stringstream ss; + for (const auto& frame : file_samples) { + ss << fmt::format("{},{},{},{}\n", dataset_id, frame.file_id, frame.index, frame.label); + } + + PQputline(conn, ss.str().c_str()); + PQputline(conn, "\\.\n"); // Note the application must explicitly send the two characters "\." 
on a final line to + // indicate to the backend that it has finished sending its data. + // https://web.mit.edu/cygwin/cygwin_v1.3.2/usr/doc/postgresql-7.1.2/html/libpq-copy.html + PQendcopy(conn); + SPDLOG_INFO(fmt::format("Copy insertion for {} samples finished.", file_samples.size())); +} + +/* + * Inserts the file frame into the database using the fallback method. + * + * The data is expected in a vector of FileFrame structs, each consisting of file_id, sample_index, and label. + * It is then inserted into the database using a prepared statement. + */ +void FileWatcher::fallback_insertion(const std::vector& file_samples, const int64_t dataset_id, + soci::session& session) { + // Prepare query + std::string query = "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES "; + + if (!file_samples.empty()) { + for (auto frame = file_samples.cbegin(); frame != std::prev(file_samples.cend()); ++frame) { + query += fmt::format("({},{},{},{}),", dataset_id, frame->file_id, frame->index, frame->label); + } + + // Add the last frame without a comma + const auto& last_frame = file_samples.back(); + query += fmt::format("({},{},{},{})", dataset_id, last_frame.file_id, last_frame.index, last_frame.label); + + session << query; + } +} diff --git a/modyn/storage/src/internal/file_watcher/file_watcher_watchdog.cpp b/modyn/storage/src/internal/file_watcher/file_watcher_watchdog.cpp new file mode 100644 index 000000000..5da735645 --- /dev/null +++ b/modyn/storage/src/internal/file_watcher/file_watcher_watchdog.cpp @@ -0,0 +1,181 @@ +#include "internal/file_watcher/file_watcher_watchdog.hpp" + +#include + +#include +#include +#include + +#include "soci/soci.h" + +using namespace modyn::storage; + +/* + * Start a new FileWatcher thread for the given dataset + * + * Also add the FileWatcher thread to the map of FileWatcher threads; we propagate the retries value to the map + * so that we can keep track of how many retries are left for a given dataset + */ +void FileWatcherWatchdog::start_file_watcher_thread(int64_t dataset_id) { + // Start a new child thread of a FileWatcher + file_watcher_thread_stop_flags_.emplace(dataset_id, false); + FileWatcher watcher(config_, dataset_id, &file_watcher_thread_stop_flags_[dataset_id], + config_["storage"]["insertion_threads"].as()); + + file_watchers_.emplace(dataset_id, std::move(watcher)); + + std::thread th(&FileWatcher::run, &file_watchers_.at(dataset_id)); + file_watcher_threads_[dataset_id] = std::move(th); +} + +/* + * Stop a FileWatcher thread for the given dataset + * + * Also remove the FileWatcher thread from the map of FileWatcher threads + */ +void FileWatcherWatchdog::stop_file_watcher_thread(int64_t dataset_id) { + if (file_watcher_threads_.contains(dataset_id)) { + // Set the stop flag for the FileWatcher thread + file_watcher_thread_stop_flags_[dataset_id].store(true); + // Wait for the FileWatcher thread to stop + if (file_watcher_threads_[dataset_id].joinable()) { + file_watcher_threads_[dataset_id].join(); + } + auto file_watcher_thread_it = file_watcher_threads_.find(dataset_id); + if (file_watcher_thread_it == file_watcher_threads_.end()) { + SPDLOG_ERROR("FileWatcher thread for dataset {} not found", dataset_id); + } else { + file_watcher_threads_.erase(file_watcher_thread_it); + } + + auto file_watcher_dataset_retries_it = file_watcher_dataset_retries_.find(dataset_id); + if (file_watcher_dataset_retries_it == file_watcher_dataset_retries_.end()) { + SPDLOG_ERROR("FileWatcher thread retries for dataset {} not found", dataset_id); +
} else { + file_watcher_dataset_retries_.erase(file_watcher_dataset_retries_it); + } + + auto file_watcher_thread_stop_flags_it = file_watcher_thread_stop_flags_.find(dataset_id); + if (file_watcher_thread_stop_flags_it == file_watcher_thread_stop_flags_.end()) { + SPDLOG_ERROR("FileWatcher thread stop flag for dataset {} not found", dataset_id); + } else { + file_watcher_thread_stop_flags_.erase(file_watcher_thread_stop_flags_it); + } + + auto file_watcher_it = file_watchers_.find(dataset_id); + if (file_watcher_it == file_watchers_.end()) { + SPDLOG_ERROR("FileWatcher object for dataset {} not found", dataset_id); + } else { + file_watchers_.erase(file_watcher_it); + } + + } else { + SPDLOG_ERROR("FileWatcher thread for dataset {} not found", dataset_id); + } +} + +void FileWatcherWatchdog::stop_and_clear_all_file_watcher_threads() { + for (auto& file_watcher_thread_flag : file_watcher_thread_stop_flags_) { + file_watcher_thread_flag.second.store(true); + } + for (auto& file_watcher_thread : file_watcher_threads_) { + if (file_watcher_thread.second.joinable()) { + file_watcher_thread.second.join(); + } + } + file_watcher_threads_.clear(); + file_watcher_dataset_retries_.clear(); + file_watcher_thread_stop_flags_.clear(); +} + +/* + * Watch the FileWatcher threads and start/stop them as needed + */ +void FileWatcherWatchdog::watch_file_watcher_threads() { + soci::session session = storage_database_connection_.get_session(); + + int64_t number_of_datasets = 0; + session << "SELECT COUNT(dataset_id) FROM datasets", soci::into(number_of_datasets); + + if (number_of_datasets == 0) { + if (file_watcher_threads_.empty()) { + // There are no FileWatcher threads running, nothing to do + return; + } + // There are no datasets in the database, stop all FileWatcher threads + stop_and_clear_all_file_watcher_threads(); + return; + } + + std::vector dataset_ids_vector(number_of_datasets); + session << "SELECT dataset_id FROM datasets", soci::into(dataset_ids_vector); + session.close(); + + const std::unordered_set dataset_ids(dataset_ids_vector.begin(), dataset_ids_vector.end()); + + const std::vector running_file_watcher_threads = get_running_file_watcher_threads(); + for (const auto& dataset_id : running_file_watcher_threads) { + if (!dataset_ids.contains(dataset_id)) { + // There is a FileWatcher thread running for a dataset that was deleted + // from the database. Stop the thread. + stop_file_watcher_thread(dataset_id); + } + } + + for (const auto& dataset_id : dataset_ids) { + if (file_watcher_dataset_retries_[dataset_id] > 2) { + if (file_watcher_dataset_retries_[dataset_id] == 3) { + SPDLOG_ERROR("FileWatcher thread for dataset {} failed to start 3 times, not trying again", dataset_id); + file_watcher_dataset_retries_[dataset_id] += 1; + } + // There have been more than 3 restart attempts for this dataset, we are not going to try again + } else if (!file_watcher_threads_.contains(dataset_id)) { + // There is no FileWatcher thread registered for this dataset. Start one. + if (!file_watcher_dataset_retries_.contains(dataset_id)) { + file_watcher_dataset_retries_[dataset_id] = 0; + } + start_file_watcher_thread(dataset_id); + } else if (!file_watcher_threads_[dataset_id].joinable()) { + // The FileWatcher thread is not running. (Re)start it. 
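+ // Each restart attempt below increments the retry counter; after three failed restarts the dataset is no
+ // longer retried (see the retry check above).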
+ start_file_watcher_thread(dataset_id); + file_watcher_dataset_retries_[dataset_id] += 1; + } + } +} + +void FileWatcherWatchdog::run() { + while (true) { + if (stop_file_watcher_watchdog_->load()) { + SPDLOG_INFO("FileWatcherWatchdog exiting run loop."); + break; + } + try { + watch_file_watcher_threads(); + } catch (const std::exception& e) { + SPDLOG_ERROR("Exception in FileWatcherWatchdog::run(): {}", e.what()); + stop(); + } + std::this_thread::sleep_for(std::chrono::seconds(file_watcher_watchdog_sleep_time_s_)); + } + + for (auto& file_watcher_thread_flag : file_watcher_thread_stop_flags_) { + file_watcher_thread_flag.second.store(true); + } + SPDLOG_INFO("FileWatcherWatchdog joining file watcher threads."); + for (auto& file_watcher_thread : file_watcher_threads_) { + if (file_watcher_thread.second.joinable()) { + file_watcher_thread.second.join(); + } + } + SPDLOG_INFO("FileWatcherWatchdog joined file watcher threads."); +} + +std::vector FileWatcherWatchdog::get_running_file_watcher_threads() { + std::vector running_file_watcher_threads = {}; + for (const auto& pair : file_watcher_threads_) { + if (pair.second.joinable()) { + running_file_watcher_threads.push_back(pair.first); + } + } + return running_file_watcher_threads; +} \ No newline at end of file diff --git a/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp new file mode 100644 index 000000000..22619d3a5 --- /dev/null +++ b/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp @@ -0,0 +1,165 @@ +#include "internal/file_wrapper/binary_file_wrapper.hpp" + +#include +#include +#include + +using namespace modyn::storage; + +int64_t BinaryFileWrapper::int_from_bytes(const unsigned char* begin, const unsigned char* end) { + int64_t value = 0; + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + value = std::accumulate(begin, end, 0LL, [](uint64_t acc, unsigned char byte) { return (acc << 8u) | byte; }); +#elif __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + const std::reverse_iterator rbegin(end); + const std::reverse_iterator rend(begin); + value = std::accumulate(rbegin, rend, 0LL, [](uint64_t acc, unsigned char byte) { return (acc << 8u) | byte; }); +#else +#error "Unknown byte order" +#endif + return value; +} + +uint64_t BinaryFileWrapper::get_number_of_samples() { return file_size_ / record_size_; } + +void BinaryFileWrapper::validate_file_extension() { + const std::string extension = file_path_.substr(file_path_.find_last_of('.') + 1); + if (extension != "bin") { + SPDLOG_ERROR("Binary file wrapper only supports .bin files."); + } +} + +/* + * Offset calculation to retrieve the label of a sample. + */ +int64_t BinaryFileWrapper::get_label(uint64_t index) { + ASSERT(index < get_number_of_samples(), "Invalid index"); + + const uint64_t label_start = index * record_size_; + + get_stream()->seekg(static_cast(label_start), std::ios::beg); + + std::vector label_vec(label_size_); + get_stream()->read(reinterpret_cast(label_vec.data()), static_cast(label_size_)); + + return int_from_bytes(label_vec.data(), label_vec.data() + label_size_); +} + +std::ifstream* BinaryFileWrapper::get_stream() { + if (!stream_->is_open()) { + stream_ = filesystem_wrapper_->get_stream(file_path_); + } + return stream_.get(); +} + +/* + * Offset calculation to retrieve all the labels of a sample. 
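+ * Each record stores label_size_ bytes of label followed by sample_size_ bytes of payload, so the label of
+ * sample i starts at byte offset i * record_size_.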
+ */ +std::vector BinaryFileWrapper::get_all_labels() { + const uint64_t num_samples = get_number_of_samples(); + std::vector labels = std::vector(); + labels.reserve(num_samples); + + for (uint64_t i = 0; i < num_samples; ++i) { + get_stream()->seekg(static_cast(i * record_size_), std::ios::beg); + + std::vector label_vec(label_size_); + get_stream()->read(reinterpret_cast(label_vec.data()), static_cast(label_size_)); + + labels.push_back(int_from_bytes(label_vec.data(), label_vec.data() + label_size_)); + } + + return labels; +} + +/* + * Offset calculation to retrieve the data of a sample interval. + */ +std::vector> BinaryFileWrapper::get_samples(uint64_t start, uint64_t end) { + ASSERT(end >= start && end <= get_number_of_samples(), "Invalid indices"); + + const uint64_t num_samples = end - start + 1; + + std::vector> samples(num_samples); + uint64_t record_start; + for (uint64_t index = 0; index < num_samples; ++index) { + record_start = (start + index) * record_size_; + get_stream()->seekg(static_cast(record_start + label_size_), std::ios::beg); + + std::vector sample_vec(sample_size_); + get_stream()->read(reinterpret_cast(sample_vec.data()), static_cast(sample_size_)); + + samples[index] = sample_vec; + } + + return samples; +} + +/* + * Offset calculation to retrieve the data of a sample. + */ +std::vector BinaryFileWrapper::get_sample(uint64_t index) { + ASSERT(index < get_number_of_samples(), "Invalid index"); + + const uint64_t record_start = index * record_size_; + + get_stream()->seekg(static_cast(record_start + label_size_), std::ios::beg); + + std::vector sample_vec(sample_size_); + get_stream()->read(reinterpret_cast(sample_vec.data()), static_cast(sample_size_)); + + return sample_vec; +} + +/* + * Offset calculation to retrieve the data of a sample interval. + */ +std::vector> BinaryFileWrapper::get_samples_from_indices( + const std::vector& indices) { + ASSERT(std::all_of(indices.begin(), indices.end(), [&](uint64_t index) { return index < get_number_of_samples(); }), + "Invalid indices"); + + std::vector> samples; + samples.reserve(indices.size()); + + uint64_t record_start = 0; + for (const uint64_t index : indices) { + record_start = index * record_size_; + + get_stream()->seekg(static_cast(record_start + label_size_), std::ios::beg); + + std::vector sample_vec(sample_size_); + get_stream()->read(reinterpret_cast(sample_vec.data()), static_cast(sample_size_)); + + samples.push_back(sample_vec); + } + + return samples; +} + +/* + * Delete the samples at the given index list. The indices are zero based. + * + * We do not support deleting samples from binary files. + * We can only delete the entire file which is done when every sample is deleted. + * This is done to avoid the overhead of updating the file after every deletion. + * + * See DeleteData in the storage grpc servicer for more details. + */ +void BinaryFileWrapper::delete_samples(const std::vector& /*indices*/) {} + +/* + * Set the file path of the file wrapper. 
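+ * Also refreshes the cached file size, asserts that it is a multiple of the record size, and closes any
+ * previously opened stream.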
+ */ +void BinaryFileWrapper::set_file_path(const std::string& path) { + file_path_ = path; + file_size_ = filesystem_wrapper_->get_file_size(path); + ASSERT(file_size_ % record_size_ == 0, "File size must be a multiple of the record size."); + + if (stream_->is_open()) { + stream_->close(); + } +} + +FileWrapperType BinaryFileWrapper::get_type() { return FileWrapperType::BINARY; } \ No newline at end of file diff --git a/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp new file mode 100644 index 000000000..dfedddc04 --- /dev/null +++ b/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp @@ -0,0 +1,110 @@ +#include "internal/file_wrapper/csv_file_wrapper.hpp" + +#include + +#include +#include +#include + +using namespace modyn::storage; + +void CsvFileWrapper::validate_file_extension() { + if (file_path_.substr(file_path_.find_last_of('.') + 1) != "csv") { + FAIL("The file extension must be .csv"); + } +} + +std::vector CsvFileWrapper::get_sample(uint64_t index) { + ASSERT(index < get_number_of_samples(), "Invalid index"); + + std::vector row = doc_.GetRow(index); + row.erase(row.begin() + static_cast(label_index_)); + std::string row_string; + for (const auto& cell : row) { + row_string += cell + separator_; + } + row_string.pop_back(); + return {row_string.begin(), row_string.end()}; +} + +std::vector> CsvFileWrapper::get_samples(uint64_t start, uint64_t end) { + ASSERT(end >= start && end <= get_number_of_samples(), "Invalid indices"); + + std::vector> samples; + const uint64_t start_t = start; + const uint64_t end_t = end; + for (uint64_t i = start_t; i < end_t; ++i) { + std::vector row = doc_.GetRow(static_cast(i)); + row.erase(row.begin() + static_cast(label_index_)); + std::string row_string; + for (const auto& cell : row) { + row_string += cell + separator_; + } + row_string.pop_back(); + samples.emplace_back(row_string.begin(), row_string.end()); + } + + return samples; +} + +std::vector> CsvFileWrapper::get_samples_from_indices(const std::vector& indices) { + ASSERT(std::all_of(indices.begin(), indices.end(), [&](uint64_t index) { return index < get_number_of_samples(); }), + "Invalid indices"); + + std::vector> samples; + for (const uint64_t index : indices) { + std::vector row = doc_.GetRow(index); + row.erase(row.begin() + static_cast(label_index_)); + std::string row_string; + for (const auto& cell : row) { + row_string += cell + separator_; + } + row_string.pop_back(); + samples.emplace_back(row_string.begin(), row_string.end()); + } + return samples; +} + +int64_t CsvFileWrapper::get_label(uint64_t index) { + ASSERT(index < get_number_of_samples(), "Invalid index"); + return doc_.GetCell(static_cast(label_index_), static_cast(index)); +} + +std::vector CsvFileWrapper::get_all_labels() { + std::vector labels; + const uint64_t num_samples = get_number_of_samples(); + for (uint64_t i = 0; i < num_samples; i++) { + labels.push_back(get_label(i)); + } + return labels; +} + +uint64_t CsvFileWrapper::get_number_of_samples() { return static_cast(doc_.GetRowCount()); } + +void CsvFileWrapper::delete_samples(const std::vector& indices) { + ASSERT(std::all_of(indices.begin(), indices.end(), [&](uint64_t index) { return index < get_number_of_samples(); }), + "Invalid indices"); + + std::vector indices_copy = indices; + std::sort(indices_copy.begin(), indices_copy.end(), std::greater<>()); + + for (const size_t index : indices_copy) { + doc_.RemoveRow(index); + } + + doc_.Save(file_path_); +} + +void 
CsvFileWrapper::set_file_path(const std::string& path) { + file_path_ = path; + + if (stream_->is_open()) { + stream_->close(); + } + + stream_ = filesystem_wrapper_->get_stream(path); + + doc_ = rapidcsv::Document(*stream_, label_params_, rapidcsv::SeparatorParams(separator_)); +} + +FileWrapperType CsvFileWrapper::get_type() { return FileWrapperType::CSV; } diff --git a/modyn/storage/src/internal/file_wrapper/file_wrapper_utils.cpp b/modyn/storage/src/internal/file_wrapper/file_wrapper_utils.cpp new file mode 100644 index 000000000..75f41c42d --- /dev/null +++ b/modyn/storage/src/internal/file_wrapper/file_wrapper_utils.cpp @@ -0,0 +1,33 @@ +#include "internal/file_wrapper/file_wrapper_utils.hpp" + +#include +#include + +#include "internal/file_wrapper/file_wrapper.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" + +namespace modyn::storage { + +std::unique_ptr get_file_wrapper(const std::string& path, const FileWrapperType& type, + const YAML::Node& file_wrapper_config, + const std::shared_ptr& filesystem_wrapper) { + ASSERT(filesystem_wrapper != nullptr, "Filesystem wrapper is nullptr"); + ASSERT(!path.empty(), "Path is empty"); + ASSERT(filesystem_wrapper->exists(path), fmt::format("Path {} does not exist", path)); + + std::unique_ptr file_wrapper; + if (type == FileWrapperType::BINARY) { + file_wrapper = std::make_unique(path, file_wrapper_config, filesystem_wrapper); + } else if (type == FileWrapperType::SINGLE_SAMPLE) { + file_wrapper = std::make_unique(path, file_wrapper_config, filesystem_wrapper); + } else if (type == FileWrapperType::CSV) { + file_wrapper = std::make_unique(path, file_wrapper_config, filesystem_wrapper); + } else if (type == FileWrapperType::INVALID_FW) { + FAIL(fmt::format("Trying to instantiate INVALID FileWrapper at path {}", path)); + } else { + FAIL(fmt::format("Unknown file wrapper type {}", static_cast(type))); + } + return file_wrapper; +} + +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp new file mode 100644 index 000000000..e9cce7dca --- /dev/null +++ b/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp @@ -0,0 +1,69 @@ +#include "internal/file_wrapper/single_sample_file_wrapper.hpp" + +#include + +#include +#include + +#include "modyn/utils/utils.hpp" + +using namespace modyn::storage; + +uint64_t SingleSampleFileWrapper::get_number_of_samples() { + ASSERT(file_wrapper_config_["file_extension"], "File wrapper configuration does not contain a file extension"); + const auto file_extension = file_wrapper_config_["file_extension"].as(); + + if (file_path_.find(file_extension) == std::string::npos) { + return 0; + } + return 1; +} + +int64_t SingleSampleFileWrapper::get_label(uint64_t /* index */) { + ASSERT(file_wrapper_config_["file_extension"], "File wrapper configuration does not contain a label file extension"); + const auto label_file_extension = file_wrapper_config_["label_file_extension"].as(); + auto label_path = std::filesystem::path(file_path_).replace_extension(label_file_extension); + + ASSERT(filesystem_wrapper_->exists(label_path), fmt::format("Label file does not exist: {}", label_path.string())); + std::vector label = filesystem_wrapper_->get(label_path); + + if (!label.empty()) { + auto label_str = std::string(reinterpret_cast(label.data()), label.size()); + return std::stoi(label_str); + } + + FAIL(fmt::format("Label file is empty: 
{}", label_path.string())); + return -1; +} + +std::vector SingleSampleFileWrapper::get_all_labels() { return std::vector{get_label(0)}; } + +std::vector SingleSampleFileWrapper::get_sample(uint64_t index) { + ASSERT(index == 0, "Single sample file wrappers can only access the first sample"); + return filesystem_wrapper_->get(file_path_); +} + +std::vector> SingleSampleFileWrapper::get_samples(uint64_t start, uint64_t end) { + ASSERT(start == 0 && end == 1, "Single sample file wrappers can only access the first sample"); + return std::vector>{get_sample(0)}; +} + +std::vector> SingleSampleFileWrapper::get_samples_from_indices( + const std::vector& indices) { + ASSERT(indices.size() == 1 && indices[0] == 0, "Single sample file wrappers can only access the first sample"); + return std::vector>{get_sample(0)}; +} + +void SingleSampleFileWrapper::validate_file_extension() { + ASSERT(file_wrapper_config_["file_extension"], "File wrapper configuration does not contain a file extension"); + + const auto file_extension = file_wrapper_config_["file_extension"].as(); + if (file_path_.find(file_extension) == std::string::npos) { + FAIL(fmt::format("File extension {} does not match file path {}", file_extension, file_path_)); + } +} + +void SingleSampleFileWrapper::delete_samples(const std::vector& /* indices */) { +} // The file will be deleted at a higher level + +FileWrapperType SingleSampleFileWrapper::get_type() { return FileWrapperType::SINGLE_SAMPLE; } diff --git a/modyn/storage/src/internal/filesystem_wrapper/filesystem_wrapper_utils.cpp b/modyn/storage/src/internal/filesystem_wrapper/filesystem_wrapper_utils.cpp new file mode 100644 index 000000000..8a9baedf0 --- /dev/null +++ b/modyn/storage/src/internal/filesystem_wrapper/filesystem_wrapper_utils.cpp @@ -0,0 +1,23 @@ +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" + +#include + +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "internal/filesystem_wrapper/local_filesystem_wrapper.hpp" +#include "modyn/utils/utils.hpp" + +namespace modyn::storage { + +std::shared_ptr get_filesystem_wrapper(const FilesystemWrapperType& type) { + std::shared_ptr filesystem_wrapper; + if (type == FilesystemWrapperType::LOCAL) { + filesystem_wrapper = std::make_shared(); + } else if (type == FilesystemWrapperType::INVALID_FSW) { + FAIL("Trying to instantiate INVALID FileSystemWrapper"); + } else { + FAIL("Unknown filesystem wrapper type"); + } + return filesystem_wrapper; +} + +} // namespace modyn::storage \ No newline at end of file diff --git a/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp b/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp new file mode 100644 index 000000000..f919f57e3 --- /dev/null +++ b/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp @@ -0,0 +1,108 @@ +#include "internal/filesystem_wrapper/local_filesystem_wrapper.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "modyn/utils/utils.hpp" + +using namespace modyn::storage; + +std::vector LocalFilesystemWrapper::get(const std::string& path) { + std::ifstream file; + file.open(path, std::ios::binary); + std::vector buffer(std::istreambuf_iterator(file), {}); + file.close(); + return buffer; +} + +std::shared_ptr LocalFilesystemWrapper::get_stream(const std::string& path) { + std::shared_ptr file = std::make_shared(); + file->open(path, std::ios::binary); + return file; +} + +bool 
LocalFilesystemWrapper::exists(const std::string& path) { return std::filesystem::exists(path); } + +std::vector LocalFilesystemWrapper::list(const std::string& path, bool recursive, std::string extension) { + std::vector paths; + + if (!std::filesystem::exists(path)) { + return paths; + } + + if (recursive) { + for (const auto& entry : std::filesystem::recursive_directory_iterator(path)) { + const std::filesystem::path& entry_path = entry.path(); + if (!std::filesystem::is_directory(entry) && entry_path.extension().string() == extension) { + paths.push_back(entry_path); + } + } + } else { + for (const auto& entry : std::filesystem::directory_iterator(path)) { + const std::filesystem::path& entry_path = entry.path(); + if (!std::filesystem::is_directory(entry) && entry_path.extension().string() == extension) { + paths.push_back(entry_path); + } + } + } + + return paths; +} + +bool LocalFilesystemWrapper::is_directory(const std::string& path) { return std::filesystem::is_directory(path); } + +bool LocalFilesystemWrapper::is_file(const std::string& path) { return std::filesystem::is_regular_file(path); } + +uint64_t LocalFilesystemWrapper::get_file_size(const std::string& path) { + return static_cast(std::filesystem::file_size(path)); +} + +template +std::time_t to_time_t(TP tp) { + using namespace std::chrono; + auto sctp = time_point_cast(tp - TP::clock::now() + system_clock::now()); + return system_clock::to_time_t(sctp); +} + +int64_t LocalFilesystemWrapper::get_modified_time(const std::string& path) { + ASSERT(is_valid_path(path), fmt::format("Invalid path: {}", path)); + ASSERT(exists(path), fmt::format("Path does not exist: {}", path)); + static_assert(sizeof(int64_t) >= sizeof(std::time_t), "Cannot cast time_t to int64_t"); + + const auto modified_time = std::filesystem::last_write_time(path); + const auto cftime = to_time_t(modified_time); + return static_cast(cftime); + + /* C++20 version, not supported by compilers yet */ + /* + const auto modified_time = std::filesystem::last_write_time(path); + const auto system_time = std::chrono::clock_cast(modified_time); + const std::time_t time = std::chrono::system_clock::to_time_t(system_time); + return static_cast(time); */ +} + +bool LocalFilesystemWrapper::is_valid_path(const std::string& path) { return std::filesystem::exists(path); } + +bool LocalFilesystemWrapper::remove(const std::string& path) { + ASSERT(!std::filesystem::is_directory(path), fmt::format("Path is a directory: {}", path)); + + if (!std::filesystem::exists(path)) { + SPDLOG_WARN(fmt::format("Trying to delete already deleted file {}", path)); + return true; + } + + SPDLOG_DEBUG("Removing file: {}", path); + + return std::filesystem::remove(path); +} + +FilesystemWrapperType LocalFilesystemWrapper::get_type() { return FilesystemWrapperType::LOCAL; } diff --git a/modyn/storage/src/internal/grpc/storage_grpc_server.cpp b/modyn/storage/src/internal/grpc/storage_grpc_server.cpp new file mode 100644 index 000000000..8966d649a --- /dev/null +++ b/modyn/storage/src/internal/grpc/storage_grpc_server.cpp @@ -0,0 +1,57 @@ +#include "internal/grpc/storage_grpc_server.hpp" + +#include + +#include "internal/grpc/storage_service_impl.hpp" + +using namespace modyn::storage; + +void StorageGrpcServer::run() { + if (!config_["storage"]["port"]) { + SPDLOG_ERROR("No port specified in config.yaml"); + return; + } + auto port = config_["storage"]["port"].as(); + std::string server_address = fmt::format("[::]:{}", port); + if (!config_["storage"]["retrieval_threads"]) { + SPDLOG_ERROR("No 
retrieval_threads specified in config.yaml"); + return; + } + auto retrieval_threads = config_["storage"]["retrieval_threads"].as(); + StorageServiceImpl service(config_, retrieval_threads); + + EnableDefaultHealthCheckService(true); + reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + grpc::ResourceQuota quota; + std::uint64_t num_cores = std::thread::hardware_concurrency(); + if (num_cores == 0) { + SPDLOG_WARN("Could not get number of cores, assuming 64."); + num_cores = 64; + } + // Note that in C++, everything is a thread in gRPC, but we want to keep the same logic as in Python + // However, we increase the threadpool a bit compared to Python + const std::uint64_t num_processes = + std::max(static_cast(4), std::min(static_cast(64), num_cores)); + const std::uint64_t num_threads_per_process = std::max(static_cast(8), num_processes / 4); + const int max_threads = static_cast(num_processes * num_threads_per_process); + SPDLOG_INFO("Using {} gRPC threads.", max_threads); + quota.SetMaxThreads(max_threads); + builder.SetResourceQuota(quota); + builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIME_MS, 2 * 60 * 60 * 1000); + builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + builder.SetMaxReceiveMessageSize(1024 * 1024 * 128); + builder.SetMaxSendMessageSize(1024 * 1024 * 128); + + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + + auto server = builder.BuildAndStart(); + SPDLOG_INFO("Server listening on {}", server_address); + + // Wait for the server to shutdown or signal to shutdown. + stop_grpc_server_->wait(false); + server->Shutdown(); + + stop(); +} \ No newline at end of file diff --git a/modyn/storage/src/internal/grpc/storage_service_impl.cpp b/modyn/storage/src/internal/grpc/storage_service_impl.cpp new file mode 100644 index 000000000..26d22c1d2 --- /dev/null +++ b/modyn/storage/src/internal/grpc/storage_service_impl.cpp @@ -0,0 +1,603 @@ +#include "internal/grpc/storage_service_impl.hpp" + +#include + +#include "internal/database/cursor_handler.hpp" +#include "internal/database/storage_database_connection.hpp" +#include "internal/file_wrapper/file_wrapper_utils.hpp" +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" +#include "modyn/utils/utils.hpp" + +using namespace modyn::storage; + +// ------- StorageServiceImpl ------- + +Status StorageServiceImpl::Get( // NOLINT readability-identifier-naming + ServerContext* context, const modyn::storage::GetRequest* request, + ServerWriter* writer) { + return Get_Impl>(context, request, writer); +} + +Status StorageServiceImpl::GetNewDataSince( // NOLINT readability-identifier-naming + ServerContext* context, const modyn::storage::GetNewDataSinceRequest* request, + ServerWriter* writer) { + return GetNewDataSince_Impl>(context, request, writer); +} + +Status StorageServiceImpl::GetDataInInterval( // NOLINT readability-identifier-naming + ServerContext* context, const modyn::storage::GetDataInIntervalRequest* request, + ServerWriter* writer) { + return GetDataInInterval_Impl>(context, request, writer); +} + +Status StorageServiceImpl::CheckAvailability( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::DatasetAvailableRequest* request, + modyn::storage::DatasetAvailableResponse* response) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + const int64_t dataset_id = get_dataset_id(session, 
request->dataset_id()); + session.close(); + SPDLOG_INFO(fmt::format("Received availability request for dataset {}", dataset_id)); + + if (dataset_id == -1) { + response->set_available(false); + return {StatusCode::OK, "Dataset does not exist."}; + } + response->set_available(true); + return {StatusCode::OK, "Dataset exists."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in CheckAvailability: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in CheckAvailability: {}", e.what())}; + } +} + +Status StorageServiceImpl::RegisterNewDataset( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::RegisterNewDatasetRequest* request, + modyn::storage::RegisterNewDatasetResponse* response) { + try { + SPDLOG_INFO(fmt::format("Received register new dataset request for {} at {}.", request->dataset_id(), + request->base_path())); + const bool success = storage_database_connection_.add_dataset( + request->dataset_id(), request->base_path(), + FilesystemWrapper::get_filesystem_wrapper_type(request->filesystem_wrapper_type()), + FileWrapper::get_file_wrapper_type(request->file_wrapper_type()), request->description(), request->version(), + request->file_wrapper_config(), request->ignore_last_timestamp(), + static_cast(request->file_watcher_interval())); + response->set_success(success); + return Status::OK; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in RegisterNewDataset: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in RegisterNewDataset: {}", e.what())}; + } +} + +Status StorageServiceImpl::GetCurrentTimestamp( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::GetCurrentTimestampRequest* /*request*/, + modyn::storage::GetCurrentTimestampResponse* response) { + try { + SPDLOG_INFO("ReceivedGetCurrentTimestamp request."); + response->set_timestamp( + std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()) + .count()); + return {StatusCode::OK, "Timestamp retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in GetCurrentTimestamp: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in GetCurrentTimestamp: {}", e.what())}; + } +} + +Status StorageServiceImpl::DeleteDataset( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::DatasetAvailableRequest* request, + modyn::storage::DeleteDatasetResponse* response) { + try { + response->set_success(false); + int64_t filesystem_wrapper_type; + + soci::session session = storage_database_connection_.get_session(); + int64_t dataset_id = get_dataset_id(session, request->dataset_id()); + SPDLOG_INFO(fmt::format("Received DeleteDataset Request for dataset {}", dataset_id)); + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + return {StatusCode::OK, "Dataset does not exist."}; + } + session << "SELECT filesystem_wrapper_type FROM datasets WHERE name = :name", soci::into(filesystem_wrapper_type), + soci::use(request->dataset_id()); + + auto filesystem_wrapper = get_filesystem_wrapper(static_cast(filesystem_wrapper_type)); + + int64_t number_of_files = 0; + session << "SELECT COUNT(file_id) FROM files WHERE dataset_id = :dataset_id", soci::into(number_of_files), + soci::use(dataset_id); + + if (number_of_files > 0) { + std::vector file_paths(number_of_files + 1); + session << "SELECT path FROM files WHERE dataset_id = :dataset_id", soci::into(file_paths), soci::use(dataset_id); + try { + for (const 
auto& file_path : file_paths) { + filesystem_wrapper->remove(file_path); + } + } catch (const modyn::utils::ModynException& e) { + SPDLOG_ERROR("Error deleting dataset: {}", e.what()); + return {StatusCode::OK, "Error deleting dataset."}; + } + } + session.close(); + const bool success = storage_database_connection_.delete_dataset(request->dataset_id(), dataset_id); + + response->set_success(success); + return Status::OK; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in DeleteDataset: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in DeleteDataset: {}", e.what())}; + } +} + +Status StorageServiceImpl::DeleteData( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::DeleteDataRequest* request, + modyn::storage::DeleteDataResponse* response) { + try { + response->set_success(false); + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + int64_t dataset_id = -1; + std::string base_path; + int64_t filesystem_wrapper_type = -1; + int64_t file_wrapper_type = -1; + std::string file_wrapper_config; + session << "SELECT dataset_id, base_path, filesystem_wrapper_type, file_wrapper_type, file_wrapper_config FROM " + "datasets WHERE name = :name", + soci::into(dataset_id), soci::into(base_path), soci::into(filesystem_wrapper_type), + soci::into(file_wrapper_type), soci::into(file_wrapper_config), soci::use(request->dataset_id()); + + SPDLOG_INFO(fmt::format("Received DeleteData Request for dataset {}", dataset_id)); + + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + return {StatusCode::OK, "Dataset does not exist."}; + } + + if (request->keys_size() == 0) { + SPDLOG_ERROR("No keys provided."); + return {StatusCode::OK, "No keys provided."}; + } + + std::vector sample_ids(request->keys_size()); + // index is int type due to gRPC typing + for (int index = 0; index < request->keys_size(); ++index) { + sample_ids[index] = request->keys(index); + } + + int64_t number_of_files = 0; + std::string sample_placeholders = fmt::format("({})", fmt::join(sample_ids, ",")); + + std::string sql = fmt::format( + "SELECT COUNT(DISTINCT file_id) FROM samples WHERE dataset_id = :dataset_id AND " + "sample_id IN {}", + sample_placeholders); + session << sql, soci::into(number_of_files), soci::use(dataset_id); + SPDLOG_INFO(fmt::format("DeleteData Request for dataset {} found {} relevant files", dataset_id, number_of_files)); + + if (number_of_files == 0) { + SPDLOG_ERROR("No samples found in dataset {}.", dataset_id); + return {StatusCode::OK, "No samples found."}; + } + + // Get the file ids + std::vector file_ids(number_of_files + 1); + sql = fmt::format("SELECT DISTINCT file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN {}", + sample_placeholders); + session << sql, soci::into(file_ids), soci::use(dataset_id); + + if (file_ids.empty()) { + SPDLOG_ERROR("No files found in dataset {}.", dataset_id); + return {StatusCode::OK, "No files found."}; + } + + auto filesystem_wrapper = get_filesystem_wrapper(static_cast(filesystem_wrapper_type)); + const YAML::Node file_wrapper_config_node = YAML::Load(file_wrapper_config); + std::string file_placeholders = fmt::format("({})", fmt::join(file_ids, ",")); + std::string index_placeholders; + + try { + std::vector file_paths(number_of_files + 1); + sql = fmt::format("SELECT path FROM files WHERE file_id IN {}", file_placeholders); + session << sql, soci::into(file_paths); + if (file_paths.size() != 
file_ids.size()) { + SPDLOG_ERROR("Error deleting data: Could not find all files."); + return {StatusCode::OK, "Error deleting data."}; + } + + auto file_wrapper = get_file_wrapper(file_paths.front(), static_cast(file_wrapper_type), + file_wrapper_config_node, filesystem_wrapper); + for (uint64_t i = 0; i < file_paths.size(); ++i) { + const auto& file_id = file_ids[i]; + const auto& path = file_paths[i]; + SPDLOG_INFO( + fmt::format("DeleteData Request for dataset {} handling path {} (file id {})", dataset_id, path, file_id)); + + file_wrapper->set_file_path(path); + + int64_t samples_to_delete = 0; + sql = fmt::format("SELECT COUNT(sample_id) FROM samples WHERE file_id = :file_id AND sample_id IN {}", + sample_placeholders); + session << sql, soci::into(samples_to_delete), soci::use(file_id); + + std::vector sample_ids_to_delete_ids(samples_to_delete + 1); + sql = fmt::format("SELECT sample_id FROM samples WHERE file_id = :file_id AND sample_id IN {}", + sample_placeholders); + session << sql, soci::into(sample_ids_to_delete_ids), soci::use(file_id); + + file_wrapper->delete_samples(sample_ids_to_delete_ids); + + index_placeholders = fmt::format("({})", fmt::join(sample_ids_to_delete_ids, ",")); + sql = fmt::format("DELETE FROM samples WHERE file_id = :file_id AND sample_id IN {}", index_placeholders); + session << sql, soci::use(file_id); + + int64_t number_of_samples_in_file = 0; + session << "SELECT number_of_samples FROM files WHERE file_id = :file_id", + soci::into(number_of_samples_in_file), soci::use(file_id); + + if (number_of_samples_in_file - samples_to_delete == 0) { + session << "DELETE FROM files WHERE file_id = :file_id", soci::use(file_id); + filesystem_wrapper->remove(path); + } else { + session << "UPDATE files SET number_of_samples = :number_of_samples WHERE file_id = :file_id", + soci::use(number_of_samples_in_file - samples_to_delete), soci::use(file_id); + } + } + } catch (const std::exception& e) { + SPDLOG_ERROR("Error deleting data: {}", e.what()); + return {StatusCode::OK, "Error deleting data."}; + } + session.close(); + response->set_success(true); + return {StatusCode::OK, "Data deleted."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in DeleteData: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in DeleteData: {}", e.what())}; + } +} + +Status StorageServiceImpl::GetDataPerWorker( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::GetDataPerWorkerRequest* request, + ServerWriter<::modyn::storage::GetDataPerWorkerResponse>* writer) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + int64_t dataset_id = get_dataset_id(session, request->dataset_id()); + + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + return {StatusCode::OK, "Dataset does not exist."}; + } + SPDLOG_INFO( + fmt::format("Received GetDataPerWorker Request for dataset {} (id = {}) and worker {} out of {} workers", + request->dataset_id(), dataset_id, request->worker_id(), request->total_workers())); + + int64_t total_keys = 0; + session << "SELECT COALESCE(SUM(number_of_samples), 0) FROM files WHERE dataset_id = :dataset_id", + soci::into(total_keys), soci::use(dataset_id); + + if (total_keys > 0) { + int64_t start_index = 0; + int64_t limit = 0; + std::tie(start_index, limit) = + get_partition_for_worker(request->worker_id(), request->total_workers(), total_keys); + + const std::string query = + fmt::format("SELECT 
sample_id FROM samples WHERE dataset_id = {} ORDER BY sample_id OFFSET {} LIMIT {}", + dataset_id, start_index, limit); + const std::string cursor_name = fmt::format("pw_cursor_{}_{}", dataset_id, request->worker_id()); + CursorHandler cursor_handler(session, storage_database_connection_.get_drivername(), query, cursor_name, 1); + + std::vector records; + std::vector record_buf; + record_buf.reserve(sample_batch_size_); + + while (true) { + records = cursor_handler.yield_per(sample_batch_size_); + + if (records.empty()) { + break; + } + + const uint64_t obtained_records = records.size(); + ASSERT(static_cast(obtained_records) <= sample_batch_size_, "Received too many samples"); + + if (static_cast(records.size()) == sample_batch_size_) { + // If we obtained a full buffer, we can emit a response directly + modyn::storage::GetDataPerWorkerResponse response; + for (const auto& record : records) { + response.add_keys(record.id); + } + + writer->Write(response); + } else { + // If not, we append to our record buf + record_buf.insert(record_buf.end(), records.begin(), records.end()); + // If our record buf is big enough, emit a message + if (static_cast(records.size()) >= sample_batch_size_) { + modyn::storage::GetDataPerWorkerResponse response; + + // sample_batch_size is signed int... + for (int64_t record_idx = 0; record_idx < sample_batch_size_; ++record_idx) { + const SampleRecord& record = record_buf[record_idx]; + response.add_keys(record.id); + } + + // Now, delete first sample_batch_size elements from vector as we are sending them + record_buf.erase(record_buf.begin(), record_buf.begin() + sample_batch_size_); + + ASSERT(static_cast(record_buf.size()) < sample_batch_size_, + "The record buffer should never have more than 2*sample_batch_size elements!"); + + writer->Write(response); + } + } + } + cursor_handler.close_cursor(); + session.close(); + + if (!record_buf.empty()) { + ASSERT(static_cast(record_buf.size()) < sample_batch_size_, + "We should have written this buffer before!"); + + modyn::storage::GetDataPerWorkerResponse response; + for (const auto& record : record_buf) { + response.add_keys(record.id); + } + + writer->Write(response); + } + } + + return {StatusCode::OK, "Data retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in GetDataPerWorker: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in GetDataPerWorker: {}", e.what())}; + } +} + +Status StorageServiceImpl::GetDatasetSize( // NOLINT readability-identifier-naming + ServerContext* /*context*/, const modyn::storage::GetDatasetSizeRequest* request, + modyn::storage::GetDatasetSizeResponse* response) { + try { + soci::session session = storage_database_connection_.get_session(); + + // Check if the dataset exists + int64_t dataset_id = get_dataset_id(session, request->dataset_id()); + + if (dataset_id == -1) { + SPDLOG_ERROR("Dataset {} does not exist.", request->dataset_id()); + return {StatusCode::OK, "Dataset does not exist."}; + } + + int64_t total_keys = 0; + session << "SELECT COALESCE(SUM(number_of_samples), 0) FROM files WHERE dataset_id = :dataset_id", + soci::into(total_keys), soci::use(dataset_id); + + session.close(); + + response->set_num_keys(total_keys); + response->set_success(true); + return {StatusCode::OK, "Dataset size retrieved."}; + } catch (const std::exception& e) { + SPDLOG_ERROR("Error in GetDatasetSize: {}", e.what()); + return {StatusCode::OK, fmt::format("Error in GetDatasetSize: {}", e.what())}; + } +} + +// ------- Helper functions ------- 
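+// Note on get_keys_per_thread below: the keys are split into contiguous [begin, end) iterator ranges of roughly
+// equal size, one per retrieval thread, with the last thread taking the remainder. For example, 10 keys across
+// 3 threads yield subset_size = 3 and the ranges [0, 3), [3, 6), and [6, 10).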
+std::vector<std::pair<std::vector<int64_t>::const_iterator, std::vector<int64_t>::const_iterator>>
+StorageServiceImpl::get_keys_per_thread(const std::vector<int64_t>& keys, uint64_t threads) {
+  ASSERT(threads > 0, "This function is only intended for multi-threaded retrieval.");
+
+  std::vector<std::pair<std::vector<int64_t>::const_iterator, std::vector<int64_t>::const_iterator>> keys_per_thread(
+      threads);
+  try {
+    if (keys.empty()) {
+      return keys_per_thread;
+    }
+
+    auto number_of_keys = static_cast<uint64_t>(keys.size());
+
+    if (number_of_keys < threads) {
+      threads = number_of_keys;
+    }
+
+    const auto subset_size = static_cast<uint64_t>(number_of_keys / threads);
+    for (uint64_t thread_id = 0; thread_id < threads; ++thread_id) {
+      // These need to be signed because we add them to iterators.
+      const auto start_index = static_cast<int64_t>(thread_id * subset_size);
+      const auto end_index = static_cast<int64_t>((thread_id + 1) * subset_size);
+
+      DEBUG_ASSERT(start_index < static_cast<int64_t>(keys.size()),
+                   fmt::format("Start Index too big! idx = {}, size = {}, thread_id = {}+1/{}, subset_size = {}",
+                               start_index, keys.size(), thread_id, threads, subset_size));
+      DEBUG_ASSERT(end_index <= static_cast<int64_t>(keys.size()),
+                   fmt::format("End Index too big! idx = {}, size = {}, thread_id = {}+1/{}, subset_size = {}",
+                               start_index, keys.size(), thread_id, threads, subset_size));
+
+      if (thread_id == threads - 1) {
+        keys_per_thread[thread_id] = std::make_pair(keys.begin() + start_index, keys.end());
+      } else {
+        keys_per_thread[thread_id] = std::make_pair(keys.begin() + start_index, keys.begin() + end_index);
+      }
+    }
+  } catch (const std::exception& e) {
+    SPDLOG_ERROR("Error in get_keys_per_thread with keys.size() = {}, retrieval_threads = {}: {}", keys.size(),
+                 threads, e.what());
+    throw;
+  }
+  return keys_per_thread;
+}
+
+std::vector<int64_t> StorageServiceImpl::get_samples_corresponding_to_file(const int64_t file_id,
+                                                                           const int64_t dataset_id,
+                                                                           const std::vector<int64_t>& request_keys,
+                                                                           soci::session& session) {
+  const auto number_of_samples = static_cast<int64_t>(request_keys.size());
+  std::vector<int64_t> sample_ids(number_of_samples + 1);
+
+  try {
+    const std::string sample_placeholders = fmt::format("({})", fmt::join(request_keys, ","));
+
+    const std::string sql = fmt::format(
+        "SELECT sample_id FROM samples WHERE file_id = :file_id AND dataset_id = "
+        ":dataset_id AND sample_id IN {}",
+        sample_placeholders);
+    session << sql, soci::into(sample_ids), soci::use(file_id), soci::use(dataset_id);
+  } catch (const std::exception& e) {
+    SPDLOG_ERROR(
+        "Error in get_samples_corresponding_to_file with file_id = {}, dataset_id = {}, number_of_samples = {}: {}",
+        file_id, dataset_id, number_of_samples, e.what());
+    throw;
+  }
+  return sample_ids;
+}
+
+std::vector<int64_t> StorageServiceImpl::get_file_ids_for_samples(const std::vector<int64_t>& request_keys,
+                                                                  const int64_t dataset_id, soci::session& session) {
+  const auto number_of_samples = static_cast<int64_t>(request_keys.size());
+  const std::string sample_placeholders = fmt::format("({})", fmt::join(request_keys, ","));
+
+  const std::string sql = fmt::format(
+      "SELECT DISTINCT file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN {}", sample_placeholders);
+  std::vector<int64_t> file_ids(number_of_samples + 1);
+  session << sql, soci::into(file_ids), soci::use(dataset_id);
+
+  return file_ids;
+}
+
+int64_t StorageServiceImpl::get_number_of_samples_in_file(int64_t file_id, soci::session& session,
+                                                          const int64_t dataset_id) {
+  int64_t number_of_samples = 0;
+  session << "SELECT number_of_samples FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id",
+      soci::into(number_of_samples),
+      soci::use(file_id), soci::use(dataset_id);
+  return number_of_samples;
+}
+
+std::tuple<int64_t, int64_t> StorageServiceImpl::get_partition_for_worker(const int64_t worker_id,
+                                                                          const int64_t total_workers,
+                                                                          const int64_t total_num_elements) {
+  if (worker_id < 0 || worker_id >= total_workers) {
+    FAIL("Worker id must be between 0 and total_workers - 1.");
+  }
+
+  const int64_t subset_size = total_num_elements / total_workers;
+  int64_t worker_subset_size = subset_size;
+
+  const int64_t threshold = total_num_elements % total_workers;
+  if (threshold > 0) {
+    if (worker_id < threshold) {
+      worker_subset_size += 1;
+      const int64_t start_index = worker_id * (subset_size + 1);
+      return {start_index, worker_subset_size};
+    }
+    const int64_t start_index = threshold * (subset_size + 1) + (worker_id - threshold) * subset_size;
+    return {start_index, worker_subset_size};
+  }
+  const int64_t start_index = worker_id * subset_size;
+  return {start_index, worker_subset_size};
+}
+
+int64_t StorageServiceImpl::get_dataset_id(soci::session& session, const std::string& dataset_name) {
+  int64_t dataset_id = -1;
+  session << "SELECT dataset_id FROM datasets WHERE name = :name", soci::into(dataset_id), soci::use(dataset_name);
+
+  return dataset_id;
+}
+
+std::vector<int64_t> StorageServiceImpl::get_file_ids(soci::session& session, const int64_t dataset_id,
+                                                      const int64_t start_timestamp, const int64_t end_timestamp) {
+  // TODO(#362): We are almost executing the same query twice since we first count and then get the data
+
+  const uint64_t number_of_files = get_file_count(session, dataset_id, start_timestamp, end_timestamp);
+
+  if (number_of_files == 0) {
+    return {};
+  }
+
+  return get_file_ids_given_number_of_files(session, dataset_id, start_timestamp, end_timestamp, number_of_files);
+}
+
+uint64_t StorageServiceImpl::get_file_count(soci::session& session, const int64_t dataset_id,
+                                            const int64_t start_timestamp, const int64_t end_timestamp) {
+  uint64_t number_of_files = -1;
+  try {
+    if (start_timestamp >= 0 && end_timestamp == -1) {
+      session << "SELECT COUNT(*) FROM files WHERE dataset_id = :dataset_id AND updated_at >= :start_timestamp",
+          soci::into(number_of_files), soci::use(dataset_id), soci::use(start_timestamp);
+    } else if (start_timestamp == -1 && end_timestamp >= 0) {
+      session << "SELECT COUNT(*) FROM files WHERE dataset_id = :dataset_id AND updated_at <= :end_timestamp",
+          soci::into(number_of_files), soci::use(dataset_id), soci::use(end_timestamp);
+    } else if (start_timestamp >= 0 && end_timestamp >= 0) {
+      session << "SELECT COUNT(*) FROM files WHERE dataset_id = :dataset_id AND updated_at >= :start_timestamp AND "
+                 "updated_at <= :end_timestamp",
+          soci::into(number_of_files), soci::use(dataset_id), soci::use(start_timestamp), soci::use(end_timestamp);
+    } else {
+      session << "SELECT COUNT(*) FROM files WHERE dataset_id = :dataset_id", soci::into(number_of_files),
+          soci::use(dataset_id);
+    }
+  } catch (const std::exception& e) {
+    SPDLOG_ERROR("Error in get_file_count with dataset_id = {}, start_timestamp = {}, end_timestamp = {}: {}",
+                 dataset_id, start_timestamp, end_timestamp, e.what());
+    throw;
+  }
+  return number_of_files;
+}
+
+std::vector<int64_t> StorageServiceImpl::get_file_ids_given_number_of_files(soci::session& session,
+                                                                            const int64_t dataset_id,
+                                                                            const int64_t start_timestamp,
+                                                                            const int64_t end_timestamp,
+                                                                            const uint64_t number_of_files) {
+  std::vector<int64_t> file_ids(number_of_files + 1);
+
+  try {
+    if (start_timestamp >= 0 && end_timestamp == -1) {
+      session << "SELECT file_id FROM files WHERE
dataset_id = :dataset_id AND updated_at >= :start_timestamp ORDER BY " + "updated_at ASC", + soci::into(file_ids), soci::use(dataset_id), soci::use(start_timestamp); + } else if (start_timestamp == -1 && end_timestamp >= 0) { + session << "SELECT file_id FROM files WHERE dataset_id = :dataset_id AND updated_at <= :end_timestamp ORDER BY " + "updated_at ASC", + soci::into(file_ids), soci::use(dataset_id), soci::use(end_timestamp); + } else if (start_timestamp >= 0 && end_timestamp >= 0) { + session << "SELECT file_id FROM files WHERE dataset_id = :dataset_id AND updated_at >= :start_timestamp AND " + "updated_at <= :end_timestamp ORDER BY updated_at ASC", + soci::into(file_ids), soci::use(dataset_id), soci::use(start_timestamp), soci::use(end_timestamp); + } else { + session << "SELECT file_id FROM files WHERE dataset_id = :dataset_id ORDER BY updated_at ASC", + soci::into(file_ids), soci::use(dataset_id); + } + } catch (const std::exception& e) { + SPDLOG_ERROR( + "Error in get_file_ids_given_number_of_files with dataset_id = {}, start_timestamp = {}, end_timestamp = {}, " + "number_of_files = {}: {}", + dataset_id, start_timestamp, end_timestamp, number_of_files, e.what()); + throw; + } + return file_ids; +} + +DatasetData StorageServiceImpl::get_dataset_data(soci::session& session, std::string& dataset_name) { + int64_t dataset_id = -1; + std::string base_path; + auto filesystem_wrapper_type = static_cast(FilesystemWrapperType::INVALID_FSW); + auto file_wrapper_type = static_cast(FileWrapperType::INVALID_FW); + std::string file_wrapper_config; + + session << "SELECT dataset_id, base_path, filesystem_wrapper_type, file_wrapper_type, file_wrapper_config FROM " + "datasets WHERE " + "name = :name", + soci::into(dataset_id), soci::into(base_path), soci::into(filesystem_wrapper_type), soci::into(file_wrapper_type), + soci::into(file_wrapper_config), soci::use(dataset_name); + + return {dataset_id, base_path, static_cast(filesystem_wrapper_type), + static_cast(file_wrapper_type), file_wrapper_config}; +} \ No newline at end of file diff --git a/modyn/storage/src/main.cpp b/modyn/storage/src/main.cpp new file mode 100644 index 000000000..65a520e6b --- /dev/null +++ b/modyn/storage/src/main.cpp @@ -0,0 +1,45 @@ +#include + +#include +#include +#include + +#include "modyn/utils/utils.hpp" +#include "storage_server.hpp" + +using namespace modyn::storage; + +void setup_logger() { spdlog::set_pattern("[%Y-%m-%d:%H:%M:%S] [%s:%#] [%l] [p%P:t%t] %v"); } + +argparse::ArgumentParser setup_argparser() { + argparse::ArgumentParser parser("Modyn Storage"); + + parser.add_argument("config").help("Modyn infrastructure configuration file"); + + return parser; +} + +int main(int argc, char* argv[]) { + /* Entrypoint for the storage service. */ + setup_logger(); + + auto parser = setup_argparser(); + + parser.parse_args(argc, argv); + + const std::string config_file = parser.get("config"); + + ASSERT(std::filesystem::exists(config_file), "Config file does not exist."); + + // Verify that the config file exists and is readable. 
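// Note, for illustration: yaml-cpp's YAML::LoadFile throws (e.g. YAML::BadFile for an unreadable file,
// YAML::ParserException for malformed YAML), so the LoadFile call below also acts as the readability check
// referred to in the comment above.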
+ const YAML::Node config = YAML::LoadFile(config_file); + + SPDLOG_INFO("Initializing storage."); + StorageServer storage(config_file); + SPDLOG_INFO("Starting storage."); + storage.run(); + + SPDLOG_INFO("Storage returned, exiting."); + + return 0; +} diff --git a/modyn/storage/src/storage_server.cpp b/modyn/storage/src/storage_server.cpp new file mode 100644 index 000000000..dee0d8623 --- /dev/null +++ b/modyn/storage/src/storage_server.cpp @@ -0,0 +1,79 @@ +#include "storage_server.hpp" + +#include + +#include +#include +#include +#include + +#include "internal/file_watcher/file_watcher_watchdog.hpp" +#include "internal/grpc/storage_grpc_server.hpp" + +using namespace modyn::storage; + +void StorageServer::run() { + /* Run the storage service. */ + SPDLOG_INFO("Running storage service. Creating tables."); + + connection_.create_tables(); + SPDLOG_INFO("Running storage service. Initializing datasets from config."); + + for (const YAML::Node& dataset_node : config_["storage"]["datasets"]) { + const auto dataset_id = dataset_node["name"].as(); + const auto base_path = dataset_node["base_path"].as(); + const auto filesystem_wrapper_type = dataset_node["filesystem_wrapper_type"].as(); + const auto file_wrapper_type = dataset_node["file_wrapper_type"].as(); + const auto description = dataset_node["description"].as(); + const auto version = dataset_node["version"].as(); + const YAML::Node& file_wrapper_config_node = dataset_node["file_wrapper_config"]; + std::ostringstream fwc_stream; + fwc_stream << file_wrapper_config_node; + const std::string file_wrapper_config = fwc_stream.str(); + + SPDLOG_INFO("Parsed filewrapper_config: {}", file_wrapper_config); + + bool ignore_last_timestamp = false; + int file_watcher_interval = 5; + + if (dataset_node["ignore_last_timestamp"]) { + ignore_last_timestamp = dataset_node["ignore_last_timestamp"].as(); + } + + if (dataset_node["file_watcher_interval"]) { + file_watcher_interval = dataset_node["file_watcher_interval"].as(); + } + + const bool success = connection_.add_dataset( + dataset_id, base_path, FilesystemWrapper::get_filesystem_wrapper_type(filesystem_wrapper_type), + FileWrapper::get_file_wrapper_type(file_wrapper_type), description, version, file_wrapper_config, + ignore_last_timestamp, file_watcher_interval); + if (!success) { + SPDLOG_ERROR(fmt::format("Could not register dataset {} - potentially it already exists.", dataset_id)); + } + } + + SPDLOG_INFO("Starting file watcher watchdog."); + + // Start the file watcher watchdog + std::thread file_watcher_watchdog_thread(&FileWatcherWatchdog::run, &file_watcher_watchdog_); + + SPDLOG_INFO("Starting storage gRPC server."); + + // Start the storage grpc server + std::thread grpc_server_thread(&StorageGrpcServer::run, &grpc_server_); + + // Wait for shutdown signal (storage_shutdown_requested_ true) + storage_shutdown_requested_.wait(false); + + SPDLOG_INFO("Shutdown requested at storage server, requesting shutdown of gRPC server."); + + stop_grpc_server_.store(true); + grpc_server_thread.join(); + + SPDLOG_INFO("gRPC server stopped."); + + stop_file_watcher_watchdog_.store(true); + file_watcher_watchdog_thread.join(); + SPDLOG_INFO("Filewatcher stopped."); +} diff --git a/modyn/storage/storage.py b/modyn/storage/storage.py deleted file mode 100644 index c2e8e3176..000000000 --- a/modyn/storage/storage.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Storage module. - -The storage module contains all classes and functions related to the retrieval of data from the -various storage backends. 
-""" - -import json -import logging -import os -import pathlib -from ctypes import c_bool -from multiprocessing import Process, Value -from typing import Tuple - -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.file_watcher.new_file_watcher_watch_dog import run_watcher_watch_dog -from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer -from modyn.utils import validate_yaml - -logger = logging.getLogger(__name__) - - -class Storage: - """Storage server. - - The storage server is responsible for the retrieval of data from the various storage backends. - """ - - def __init__(self, modyn_config: dict) -> None: - """Initialize the storage server. - - Args: - modyn_config (dict): Configuration of the modyn module. - - Raises: - ValueError: Invalid configuration. - """ - self.modyn_config = modyn_config - - valid, errors = self._validate_config() - if not valid: - raise ValueError(f"Invalid configuration: {errors}") - - def _validate_config(self) -> Tuple[bool, list[str]]: - schema_path = ( - pathlib.Path(os.path.abspath(__file__)).parent.parent / "config" / "schema" / "modyn_config_schema.yaml" - ) - return validate_yaml(self.modyn_config, schema_path) - - def run(self) -> None: - """Run the storage server. - - Raises: - ValueError: Failed to add dataset. - """ - #  Create the database tables. - with StorageDatabaseConnection(self.modyn_config) as database: - database.create_tables() - - for dataset in self.modyn_config["storage"]["datasets"]: - if not database.add_dataset( - dataset["name"], - dataset["base_path"], - dataset["filesystem_wrapper_type"], - dataset["file_wrapper_type"], - dataset["description"], - dataset["version"], - json.dumps(dataset["file_wrapper_config"]), - dataset["ignore_last_timestamp"] if "ignore_last_timestamp" in dataset else False, - dataset["file_watcher_interval"] if "file_watcher_interval" in dataset else 5, - ): - raise ValueError(f"Failed to add dataset {dataset['name']}") - - #  Start the dataset watcher process in a different thread. - should_stop = Value(c_bool, False) - watchdog = Process(target=run_watcher_watch_dog, args=(self.modyn_config, should_stop)) - watchdog.start() - - #  Start the storage grpc server. - with StorageGRPCServer(self.modyn_config) as server: - server.wait_for_termination() - - should_stop.value = True # type: ignore # See https://github.com/python/typeshed/issues/8799 - watchdog.join() diff --git a/modyn/storage/storage_entrypoint.py b/modyn/storage/storage_entrypoint.py deleted file mode 100644 index 194fc39b0..000000000 --- a/modyn/storage/storage_entrypoint.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Entrypoint for the storage service.""" - -import argparse -import logging -import multiprocessing as mp -import os -import pathlib - -import yaml -from modyn.storage.storage import Storage - -logging.basicConfig( - level=logging.NOTSET, - format="[%(asctime)s] [%(filename)15s:%(lineno)4d] %(levelname)-8s %(message)s", - datefmt="%Y-%m-%d:%H:%M:%S", -) -logger = logging.getLogger(__name__) - -# We need to do this at the top because other dependencies otherwise set fork. -try: - mp.set_start_method("spawn") -except RuntimeError as error: - if mp.get_start_method() != "spawn" and "PYTEST_CURRENT_TEST" not in os.environ: - logger.error("Start method is already set to {}", mp.get_start_method()) - raise error - - -def setup_argparser() -> argparse.ArgumentParser: - """Set up the argument parser. 
- - Returns: - argparse.ArgumentParser: Argument parser - """ - parser_ = argparse.ArgumentParser(description="Modyn Storage") - parser_.add_argument("config", type=pathlib.Path, action="store", help="Modyn infrastructure configuration file") - - return parser_ - - -def main() -> None: - """Entrypoint for the storage service.""" - parser = setup_argparser() - args = parser.parse_args() - - assert args.config.is_file(), f"File does not exist: {args.config}" - - with open(args.config, "r", encoding="utf-8") as config_file: - modyn_config = yaml.safe_load(config_file) - - logger.info("Initializing storage.") - storage = Storage(modyn_config) - logger.info("Starting storage.") - storage.run() - - logger.info("Storage returned, exiting.") - - -if __name__ == "__main__": - main() diff --git a/modyn/supervisor/internal/triggers/timetrigger.py b/modyn/supervisor/internal/triggers/timetrigger.py index 765e643c0..ed944164c 100644 --- a/modyn/supervisor/internal/triggers/timetrigger.py +++ b/modyn/supervisor/internal/triggers/timetrigger.py @@ -17,18 +17,18 @@ def __init__(self, trigger_config: dict): if not validate_timestr(timestr): raise ValueError(f"Invalid time string: {timestr}\nValid format is [s|m|h|d|w].") - self.trigger_every_ms: int = convert_timestr_to_seconds(trigger_config["trigger_every"]) * 1000 + self.trigger_every_s: int = convert_timestr_to_seconds(trigger_config["trigger_every"]) self.next_trigger_at: Optional[int] = None - if self.trigger_every_ms < 1: - raise ValueError(f"trigger_every must be > 0, but is {self.trigger_every_ms}") + if self.trigger_every_s < 1: + raise ValueError(f"trigger_every must be > 0, but is {self.trigger_every_s}") super().__init__(trigger_config) def inform(self, new_data: list[tuple[int, int, int]]) -> list[int]: if self.next_trigger_at is None: if len(new_data) > 0: - self.next_trigger_at = new_data[0][1] + self.trigger_every_ms # new_data is sorted + self.next_trigger_at = new_data[0][1] + self.trigger_every_s # new_data is sorted else: return [] @@ -44,9 +44,9 @@ def inform(self, new_data: list[tuple[int, int, int]]) -> list[int]: # This means that there was a trigger before the first item that we got informed about # However, there might have been multiple triggers, e.g., if there is one trigger every second # and 5 seconds have passed since the last item came through - # This is caught by our while loop which increases step by step for `trigger_every_ms`. + # This is caught by our while loop which increases step by step for `trigger_every_s`. 
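# Illustrative example: with trigger_every_s = 1 and next_trigger_at = 10, an informed item
# carrying timestamp 15 means several trigger points (10, 11, 12, ...) have already passed;
# the loop appends idx - 1 once per passed trigger point and advances next_trigger_at by
# trigger_every_s on every iteration until it has caught up with the item's timestamp.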
triggering_indices.append(idx - 1) - self.next_trigger_at += self.trigger_every_ms + self.next_trigger_at += self.trigger_every_s return triggering_indices diff --git a/modyn/tests/CMakeLists.txt b/modyn/tests/CMakeLists.txt index 820c1eb00..45af7a3f6 100644 --- a/modyn/tests/CMakeLists.txt +++ b/modyn/tests/CMakeLists.txt @@ -11,10 +11,6 @@ set( utils/test_utils.hpp ) -add_library(modyn-test-utils-objs OBJECT ${MODYN_TEST_UTILS_SOURCES}) -target_include_directories(modyn-test-utils-objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/utils) -target_link_libraries(modyn-test-utils-objs PUBLIC gtest gmock spdlog fmt modyn example_extension trigger_sample_storage) - #################################################t # UNIT TESTS ################################################## @@ -26,7 +22,46 @@ set( common/trigger_sample/test_trigger_sample_storage.cpp ) -# TODO(MaxiBoether): When merging into storage, only add tests for storage when MODYN_BUILD_STORAGE is enabled +#################################################t +# STORAGE TESTS +################################################## +if (${MODYN_BUILD_STORAGE}) + message(STATUS "Storage is included in this test build.") + list( + APPEND MODYN_TEST_UTILS_SOURCES + + storage/storage_test_utils.cpp + storage/storage_test_utils.hpp + ) + + list( + APPEND MODYN_TEST_SOURCES + + storage/internal/file_watcher/file_watcher_test.cpp + storage/internal/file_watcher/file_watcher_watchdog_test.cpp + storage/internal/database/storage_database_connection_test.cpp + storage/internal/database/cursor_handler_test.cpp + storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp + storage/internal/file_wrapper/mock_file_wrapper.hpp + storage/internal/file_wrapper/binary_file_wrapper_test.cpp + storage/internal/file_wrapper/csv_file_wrapper_test.cpp + storage/internal/file_wrapper/file_wrapper_utils_test.cpp + storage/internal/filesystem_wrapper/local_filesystem_wrapper_test.cpp + storage/internal/filesystem_wrapper/mock_filesystem_wrapper.hpp + storage/internal/filesystem_wrapper/filesystem_wrapper_utils_test.cpp + storage/internal/grpc/storage_service_impl_test.cpp + ) + +endif () + +add_library(modyn-test-utils-objs OBJECT ${MODYN_TEST_UTILS_SOURCES}) +target_include_directories(modyn-test-utils-objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/utils) +target_link_libraries(modyn-test-utils-objs PUBLIC gtest gmock spdlog fmt yaml-cpp modyn example_extension trigger_sample_storage) + +if (${MODYN_BUILD_STORAGE}) + target_link_libraries(modyn-test-utils-objs PUBLIC modyn-storage-library) + target_include_directories(modyn-test-utils-objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/storage) +endif () add_library(modyn-test-objs OBJECT ${MODYN_TEST_SOURCES}) target_link_libraries(modyn-test-objs PRIVATE modyn-test-utils-objs) diff --git a/modyn/tests/storage/internal/database/cursor_handler_test.cpp b/modyn/tests/storage/internal/database/cursor_handler_test.cpp new file mode 100644 index 000000000..dc0273677 --- /dev/null +++ b/modyn/tests/storage/internal/database/cursor_handler_test.cpp @@ -0,0 +1,82 @@ +#include "internal/database/cursor_handler.hpp" + +#include +#include + +#include "test_utils.hpp" + +using namespace modyn::storage; + +class CursorHandlerTest : public ::testing::Test { + protected: + void SetUp() override { + modyn::test::TestUtils::create_dummy_yaml(); + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + connection.create_tables(); + + soci::session session = connection.get_session(); + + for 
(int64_t i = 0; i < 1000; i++) { + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, :file_id, :sample_index, " + ":label)", + soci::use(i, "file_id"), soci::use(i, "sample_index"), soci::use(i, "label"); + } + } + void TearDown() override { + if (std::filesystem::exists("test.db")) { + std::filesystem::remove("test.db"); + } + } +}; + +TEST_F(CursorHandlerTest, TestCheckCursorInitialized) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + CursorHandler cursor_handler(session, connection.get_drivername(), "SELECT * FROM samples", "test_cursor"); + + ASSERT_NO_THROW(cursor_handler.close_cursor()); +} + +TEST_F(CursorHandlerTest, TestYieldPerSQLite3ThreeColumns) { // NOLINT (readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + CursorHandler cursor_handler(session, connection.get_drivername(), + "SELECT sample_id, label, sample_index FROM samples", "test_cursor"); + + std::vector record(100); + for (int64_t i = 0; i < 10; i++) { + ASSERT_NO_THROW(record = cursor_handler.yield_per(100)); + ASSERT_EQ(record.size(), 100); + for (int64_t j = 0; j < 100; j++) { + ASSERT_EQ(record[j].id, j + i * 100 + 1); + ASSERT_EQ(record[j].column_1, j + i * 100); + ASSERT_EQ(record[j].column_2, j + i * 100); + } + } + cursor_handler.close_cursor(); +} + +TEST_F(CursorHandlerTest, TestYieldPerSQLite3TwoColumns) { // NOLINT (readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + CursorHandler cursor_handler(session, connection.get_drivername(), "SELECT sample_id, label FROM samples", + "test_cursor", 2); + + std::vector record(100); + for (int64_t i = 0; i < 10; i++) { + ASSERT_NO_THROW(record = cursor_handler.yield_per(100)); + ASSERT_EQ(record.size(), 100); + for (int64_t j = 0; j < 100; j++) { + ASSERT_EQ(record[j].id, j + i * 100 + 1); + ASSERT_EQ(record[j].column_1, j + i * 100); + } + } + cursor_handler.close_cursor(); +} \ No newline at end of file diff --git a/modyn/tests/storage/internal/database/models/test_dataset.py b/modyn/tests/storage/internal/database/models/test_dataset.py deleted file mode 100644 index 5972bfc3e..000000000 --- a/modyn/tests/storage/internal/database/models/test_dataset.py +++ /dev/null @@ -1,115 +0,0 @@ -# pylint: disable=redefined-outer-name -import pytest -from modyn.storage.internal.database.models import Dataset, Sample -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - - -@pytest.fixture(autouse=True) -def session(): - engine = create_engine("sqlite:///:memory:", echo=True) - sess = sessionmaker(bind=engine)() - - Sample.ensure_pks_correct(sess) - Dataset.metadata.create_all(engine) - - yield sess - - sess.close() - engine.dispose() - - -def test_add_dataset(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - 
last_timestamp=0, - ) - session.add(dataset) - session.commit() - - assert session.query(Dataset).filter(Dataset.name == "test").first() is not None - assert session.query(Dataset).filter(Dataset.name == "test").first().base_path == "test" - assert ( - session.query(Dataset).filter(Dataset.name == "test").first().filesystem_wrapper_type - == FilesystemWrapperType.LocalFilesystemWrapper - ) - assert ( - session.query(Dataset).filter(Dataset.name == "test").first().file_wrapper_type - == FileWrapperType.SingleSampleFileWrapper - ) - assert session.query(Dataset).filter(Dataset.name == "test").first().description == "test" - assert session.query(Dataset).filter(Dataset.name == "test").first().version == "test" - - -def test_update_dataset(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - session.query(Dataset).filter(Dataset.name == "test").update( - { - "base_path": "test2", - "file_wrapper_type": FileWrapperType.SingleSampleFileWrapper, - "description": "test2", - "version": "test2", - } - ) - session.commit() - - assert session.query(Dataset).filter(Dataset.name == "test").first().base_path == "test2" - assert ( - session.query(Dataset).filter(Dataset.name == "test").first().file_wrapper_type - == FileWrapperType.SingleSampleFileWrapper - ) - assert session.query(Dataset).filter(Dataset.name == "test").first().description == "test2" - assert session.query(Dataset).filter(Dataset.name == "test").first().version == "test2" - - -def test_repr(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - assert repr(dataset) == "" - - -def test_delete_dataset(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - session.query(Dataset).filter(Dataset.name == "test").delete() - session.commit() - - assert session.query(Dataset).filter(Dataset.name == "test").first() is None diff --git a/modyn/tests/storage/internal/database/models/test_file.py b/modyn/tests/storage/internal/database/models/test_file.py deleted file mode 100644 index d4dfb99a5..000000000 --- a/modyn/tests/storage/internal/database/models/test_file.py +++ /dev/null @@ -1,122 +0,0 @@ -# pylint: disable=redefined-outer-name -import pytest -from modyn.storage.internal.database.models import Dataset, File -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType -from modyn.utils import current_time_millis -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -NOW = current_time_millis() - - -@pytest.fixture(autouse=True) -def session(): - engine = create_engine("sqlite:///:memory:", echo=True) - sess = sessionmaker(bind=engine)() - - Dataset.metadata.create_all(engine) - File.metadata.create_all(engine) - - yield sess - - sess.close() - 
engine.dispose() - - -def test_add_file(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - assert session.query(File).filter(File.path == "test").first() is not None - assert session.query(File).filter(File.path == "test").first().dataset == dataset - assert session.query(File).filter(File.path == "test").first().created_at == now - assert session.query(File).filter(File.path == "test").first().updated_at == now - - -def test_update_file(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - now = NOW - - session.query(File).filter(File.path == "test").update({"path": "test2", "created_at": now, "updated_at": now}) - session.commit() - - assert session.query(File).filter(File.path == "test2").first() is not None - assert session.query(File).filter(File.path == "test2").first().dataset == dataset - assert session.query(File).filter(File.path == "test2").first().created_at == now - assert session.query(File).filter(File.path == "test2").first().updated_at == now - - -def test_delete_file(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - session.query(File).filter(File.path == "test").delete() - session.commit() - - assert session.query(File).filter(File.path == "test").first() is None - - -def test_repr_file(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - assert repr(file) == "" diff --git a/modyn/tests/storage/internal/database/models/test_sample.py b/modyn/tests/storage/internal/database/models/test_sample.py deleted file mode 100644 index 10247e2d1..000000000 --- a/modyn/tests/storage/internal/database/models/test_sample.py +++ /dev/null @@ -1,137 +0,0 @@ -# pylint: disable=redefined-outer-name -import pytest -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType -from modyn.utils import 
current_time_millis -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -NOW = current_time_millis() - - -@pytest.fixture(autouse=True) -def session(): - engine = create_engine("sqlite:///:memory:", echo=True) - sess = sessionmaker(bind=engine)() - Sample.ensure_pks_correct(sess) - - Dataset.metadata.create_all(engine) - - yield sess - - sess.close() - engine.dispose() - - -def test_add_sample(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - sample = Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=b"test") - session.add(sample) - session.commit() - - sample_id = sample.sample_id - - assert session.query(Sample).filter(Sample.sample_id == sample_id).first() is not None - assert session.query(Sample).filter(Sample.sample_id == sample_id).first().file_id == file.file_id - assert session.query(Sample).filter(Sample.sample_id == sample_id).first().index == 0 - assert session.query(Sample).filter(Sample.sample_id == sample_id).first().label == b"test" - - -def test_update_sample(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - sample = Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=b"test") - session.add(sample) - session.commit() - - sample_id = sample.sample_id - - session.query(Sample).filter(Sample.sample_id == sample_id).update({"index": 1}) - - assert session.query(Sample).filter(Sample.sample_id == sample_id).first().index == 1 - - -def test_delete_sample(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - sample = Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=b"test") - session.add(sample) - session.commit() - - sample_id = sample.sample_id - - session.query(Sample).filter(Sample.sample_id == sample_id).delete() - - assert session.query(Sample).filter(Sample.sample_id == sample_id).first() is None - - -def test_repr(session): - dataset = Dataset( - name="test", - base_path="test", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - description="test", - version="test", - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - now = NOW - file = File(dataset=dataset, path="test", created_at=now, updated_at=now, number_of_samples=0) - session.add(file) - session.commit() - - sample 
= Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=b"test") - session.add(sample) - session.commit() - - assert repr(sample) == "" diff --git a/modyn/tests/storage/internal/database/storage_database_connection_test.cpp b/modyn/tests/storage/internal/database/storage_database_connection_test.cpp new file mode 100644 index 000000000..f01b09c24 --- /dev/null +++ b/modyn/tests/storage/internal/database/storage_database_connection_test.cpp @@ -0,0 +1,138 @@ +#include "internal/database/storage_database_connection.hpp" + +#include +#include +#include + +#include + +#include "modyn/utils/utils.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +class StorageDatabaseConnectionTest : public ::testing::Test { + protected: + void TearDown() override { + if (std::filesystem::exists("test.db")) { + std::filesystem::remove("test.db"); + } + } +}; + +TEST_F(StorageDatabaseConnectionTest, TestGetSession) { + YAML::Node config = modyn::test::TestUtils::get_dummy_config(); // NOLINT + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.get_session()); +} + +TEST_F(StorageDatabaseConnectionTest, TestInvalidDriver) { + YAML::Node config = modyn::test::TestUtils::get_dummy_config(); // NOLINT + config["storage"]["database"]["drivername"] = "invalid"; + ASSERT_THROW(const StorageDatabaseConnection connection(config), modyn::utils::ModynException); +} + +TEST_F(StorageDatabaseConnectionTest, TestCreateTables) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.create_tables()); + + const StorageDatabaseConnection connection2(config); + soci::session session = connection2.get_session(); + + // Assert datasets, files and samples tables exist + int number_of_tables = 0; // NOLINT + session << "SELECT COUNT(*) FROM sqlite_master WHERE type='table';", soci::into(number_of_tables); + ASSERT_EQ(number_of_tables, 4); // 3 tables + 1 + // sqlite_sequence + // table +} + +TEST_F(StorageDatabaseConnectionTest, TestAddDataset) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.create_tables()); + + const StorageDatabaseConnection connection2(config); + soci::session session = connection2.get_session(); + + // Assert no datasets exist + int number_of_datasets = 0; // NOLINT + session << "SELECT COUNT(*) FROM datasets;", soci::into(number_of_datasets); + ASSERT_EQ(number_of_datasets, 0); + + // Add dataset + ASSERT_TRUE(connection2.add_dataset("test_dataset", "test_base_path", FilesystemWrapperType::LOCAL, + FileWrapperType::SINGLE_SAMPLE, "test_description", "test_version", + "test_file_wrapper_config", false, 0)); + + // Assert dataset exists + session << "SELECT COUNT(*) FROM datasets;", soci::into(number_of_datasets); + ASSERT_EQ(number_of_datasets, 1); + std::string dataset_name; // NOLINT + session << "SELECT name FROM datasets;", soci::into(dataset_name); + ASSERT_EQ(dataset_name, "test_dataset"); +} + +TEST_F(StorageDatabaseConnectionTest, TestAddExistingDataset) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.create_tables()); + + // Add dataset + ASSERT_TRUE(connection.add_dataset("test_dataset", "test_base_path", FilesystemWrapperType::LOCAL, + FileWrapperType::SINGLE_SAMPLE, "test_description", 
"test_version", + "test_file_wrapper_config", false, 0)); + + // Add existing dataset + ASSERT_FALSE(connection.add_dataset("test_dataset", "test_base_path2", FilesystemWrapperType::LOCAL, + FileWrapperType::SINGLE_SAMPLE, "test_description", "test_version", + "test_file_wrapper_config", false, 0)); + + soci::session session = connection.get_session(); + std::string base_path; + session << "SELECT base_path FROM datasets where name='test_dataset';", soci::into(base_path); + ASSERT_EQ(base_path, "test_base_path"); +} + +TEST_F(StorageDatabaseConnectionTest, TestDeleteDataset) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.create_tables()); + + const StorageDatabaseConnection connection2(config); + soci::session session = connection2.get_session(); + + // Assert no datasets exist + int number_of_datasets = 0; // NOLINT + session << "SELECT COUNT(*) FROM datasets;", soci::into(number_of_datasets); + ASSERT_EQ(number_of_datasets, 0); + + // Add dataset + ASSERT_NO_THROW(connection2.add_dataset("test_dataset", "test_base_path", FilesystemWrapperType::LOCAL, + FileWrapperType::SINGLE_SAMPLE, "test_description", "test_version", + "test_file_wrapper_config", false, 0)); + + // Assert dataset exists + session << "SELECT COUNT(*) FROM datasets;", soci::into(number_of_datasets); + ASSERT_EQ(number_of_datasets, 1); + + std::string dataset_name; // NOLINT + std::int64_t dataset_id; // NOLINT + session << "SELECT name, dataset_id FROM datasets;", soci::into(dataset_name), soci::into(dataset_id); + ASSERT_EQ(dataset_name, "test_dataset"); + + // Delete dataset + ASSERT_TRUE(connection2.delete_dataset("test_dataset", dataset_id)); + + // Assert no datasets exist + session << "SELECT COUNT(*) FROM datasets;", soci::into(number_of_datasets); + ASSERT_EQ(number_of_datasets, 0); +} + +TEST_F(StorageDatabaseConnectionTest, TestDeleteNonExistingDataset) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const StorageDatabaseConnection connection(config); + ASSERT_NO_THROW(connection.create_tables()); +} diff --git a/modyn/tests/storage/internal/database/test_database_storage_utils.py b/modyn/tests/storage/internal/database/test_database_storage_utils.py deleted file mode 100644 index 883b7458a..000000000 --- a/modyn/tests/storage/internal/database/test_database_storage_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -from modyn.storage.internal.database.storage_database_utils import get_file_wrapper, get_filesystem_wrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType, InvalidFileWrapperTypeException -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import ( - FilesystemWrapperType, - InvalidFilesystemWrapperTypeException, -) - - -def test_get_filesystem_wrapper(): - filesystem_wrapper = get_filesystem_wrapper(FilesystemWrapperType.LocalFilesystemWrapper, "/tmp/modyn") - assert filesystem_wrapper is not None - assert filesystem_wrapper.base_path == "/tmp/modyn" - assert filesystem_wrapper.filesystem_wrapper_type == FilesystemWrapperType.LocalFilesystemWrapper - - -def test_get_filesystem_wrapper_with_invalid_type(): - with pytest.raises(InvalidFilesystemWrapperTypeException): - filesystem_wrapper = get_filesystem_wrapper("invalid", "/tmp/modyn") - assert filesystem_wrapper is None - - -def test_get_file_wrapper(): - file_wrapper = get_file_wrapper(FileWrapperType.SingleSampleFileWrapper, "/tmp/modyn", "{}", None) 
- assert file_wrapper is not None - assert file_wrapper.file_wrapper_type == FileWrapperType.SingleSampleFileWrapper - - -def test_get_file_wrapper_with_invalid_type(): - with pytest.raises(InvalidFileWrapperTypeException): - file_wrapper = get_file_wrapper("invalid", "/tmp/modyn", "{}", None) - assert file_wrapper is None diff --git a/modyn/tests/storage/internal/database/test_storage_database_connection.py b/modyn/tests/storage/internal/database/test_storage_database_connection.py deleted file mode 100644 index d39940ed9..000000000 --- a/modyn/tests/storage/internal/database/test_storage_database_connection.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection - - -def get_minimal_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": 0, - "database": ":memory:", - }, - } - } - - -def get_invalid_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "postgres", - "username": "", - "password": "", - "host": "", - "port": 10, - "database": "/tmp/modyn/modyn.db", - }, - } - } - - -def test_database_connection(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - assert database.session is not None - assert database.add_dataset("test", "/tmp/modyn", "local", "local", "test", "0.0.1", "{}") is True - - -def test_database_connection_with_existing_dataset(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - assert database.session is not None - assert ( - database.add_dataset( - "test", "/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - assert ( - database.add_dataset( - "test", "/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - - -def test_database_connection_with_existing_dataset_and_different_base_path(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - assert database.session is not None - assert ( - database.add_dataset( - "test", "/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - assert ( - database.add_dataset( - "test", "/tmp/modyn2", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - assert database.session.query(Dataset).filter(Dataset.name == "test").first().base_path == "/tmp/modyn2" - - -def test_database_connection_failure(): - with pytest.raises(Exception): - with StorageDatabaseConnection(get_invalid_modyn_config()) as database: - database.create_tables() - assert database.session is not None - assert ( - database.add_dataset( - "test", "/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - - -def test_add_dataset_failure(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - assert ( - database.add_dataset( - "test", "/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is False - ) - - -def test_delete_dataset(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - assert database.session is not None - assert ( - database.add_dataset( - "test", 
"/tmp/modyn", "LocalFilesystemWrapper", "SingleSampleFileWrapper", "test", "0.0.1", "{}" - ) - is True - ) - dataset = database.session.query(Dataset).filter(Dataset.name == "test").first() - file = File(dataset=dataset, path="/tmp/modyn/test", created_at=0, updated_at=0, number_of_samples=1) - database.session.add(file) - database.session.commit() - file = database.session.query(File).filter(File.path == "/tmp/modyn/test").first() - sample = Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=1) - database.session.add(sample) - database.session.commit() - assert database.delete_dataset("test") is True - assert database.session.query(Dataset).filter(Dataset.name == "test").first() is None - assert database.session.query(File).all() == [] - assert database.session.query(Sample).all() == [] - - -def test_delete_dataset_failure(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - assert database.delete_dataset("test") is False diff --git a/modyn/tests/storage/internal/file_watcher/file_watcher_test.cpp b/modyn/tests/storage/internal/file_watcher/file_watcher_test.cpp new file mode 100644 index 000000000..5ef855521 --- /dev/null +++ b/modyn/tests/storage/internal/file_watcher/file_watcher_test.cpp @@ -0,0 +1,458 @@ +#include "internal/file_watcher/file_watcher.hpp" + +#include +#include +#include +#include +#include + +#include + +#include "internal/database/storage_database_connection.hpp" +#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp" +#include "modyn/utils/utils.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +class FileWatcherTest : public ::testing::Test { + protected: + std::string tmp_dir_; + + FileWatcherTest() : tmp_dir_{std::filesystem::temp_directory_path().string() + "/file_watcher_test"} {} + + void SetUp() override { + modyn::test::TestUtils::create_dummy_yaml(); + // Create temporary directory + std::filesystem::create_directory(tmp_dir_); + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + connection.create_tables(); + + // Add a dataset to the database + connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + } + + void TearDown() override { + modyn::test::TestUtils::delete_dummy_yaml(); + if (std::filesystem::exists("test.db")) { + std::filesystem::remove("test.db"); + } + // Remove temporary directory + std::filesystem::remove_all(tmp_dir_); + } +}; + +TEST_F(FileWatcherTest, TestConstructor) { + std::atomic stop_file_watcher = false; + ASSERT_NO_THROW(const FileWatcher watcher(YAML::LoadFile("config.yaml"), 1, &stop_file_watcher)); +} + +TEST_F(FileWatcherTest, TestSeek) { // NOLINT(readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + const StorageDatabaseConnection connection(config); + + soci::session session = connection.get_session(); + + // Add a file to the temporary directory + const std::string test_file_path = tmp_dir_ + "/test_file.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const 
std::string label_file_path = tmp_dir_ + "/test_file.lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + // Seek the temporary directory + ASSERT_NO_THROW(watcher.seek(session)); + + // Check if the file is added to the database + std::vector file_paths(1); + session << "SELECT path FROM files", soci::into(file_paths); + ASSERT_EQ(file_paths[0], test_file_path); + + // Check if the sample is added to the database + std::vector sample_ids(1); + session << "SELECT sample_id FROM samples", soci::into(sample_ids); + ASSERT_EQ(sample_ids[0], 1); + + // Assert the last timestamp of the dataset is updated + const int32_t dataset_id = 1; + int32_t last_timestamp; + session << "SELECT last_timestamp FROM datasets WHERE dataset_id = :id", soci::use(dataset_id), + soci::into(last_timestamp); + + ASSERT_TRUE(last_timestamp > 0); +} + +TEST_F(FileWatcherTest, TestSeekDataset) { // NOLINT(readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + // Add a file to the temporary directory + const std::string test_file_path = tmp_dir_ + "/test_file.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test_file.lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + ASSERT_NO_THROW(watcher.seek_dataset(session)); + + // Check if the file is added to the database + std::vector file_paths(1); + session << "SELECT path FROM files", soci::into(file_paths); + ASSERT_EQ(file_paths[0], test_file_path); + + // Check if the sample is added to the database + std::vector sample_ids(1); + session << "SELECT sample_id FROM samples", soci::into(sample_ids); + ASSERT_EQ(sample_ids[0], 1); +} + +TEST_F(FileWatcherTest, TestExtractCheckFileForInsertion) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + const std::shared_ptr filesystem_wrapper = std::make_shared(); + + EXPECT_CALL(*filesystem_wrapper, get_modified_time(testing::_)).WillOnce(testing::Return(1000)); + + ASSERT_TRUE(FileWatcher::check_file_for_insertion("test.txt", false, 0, 1, filesystem_wrapper, session)); + + EXPECT_CALL(*filesystem_wrapper, get_modified_time(testing::_)).WillOnce(testing::Return(0)); + + ASSERT_FALSE(FileWatcher::check_file_for_insertion("test.txt", false, 1000, 1, filesystem_wrapper, session)); + + ASSERT_TRUE(FileWatcher::check_file_for_insertion("test.txt", true, 0, 1, filesystem_wrapper, session)); + + session << "INSERT INTO files (file_id, dataset_id, path, updated_at) VALUES " + "(1, 1, 'test.txt', 1000)"; + + ASSERT_FALSE(FileWatcher::check_file_for_insertion("test.txt", false, 0, 1, filesystem_wrapper, session)); + + ASSERT_FALSE(FileWatcher::check_file_for_insertion("test.txt", false, 1000, 1, filesystem_wrapper, session)); +} + 
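// For illustration: the assertions in TestExtractCheckFileForInsertion above encode the insertion rule exercised
// here -- a path is only considered for insertion if it is not yet present in the files table and either
// ignore_last_timestamp is set or its modification time is newer than the given timestamp.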
+TEST_F(FileWatcherTest, TestUpdateFilesInDirectory) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + const std::shared_ptr filesystem_wrapper = std::make_shared(); + watcher.filesystem_wrapper = filesystem_wrapper; + + // Add a file to the temporary directory + const std::string test_file_path = tmp_dir_ + "/test.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test.lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + const std::vector files = {test_file_path, label_file_path}; + + EXPECT_CALL(*filesystem_wrapper, list(testing::_, testing::_, testing::_)).WillOnce(testing::Return(files)); + EXPECT_CALL(*filesystem_wrapper, get_modified_time(testing::_)).WillRepeatedly(testing::Return(1000)); + ON_CALL(*filesystem_wrapper, exists(testing::_)).WillByDefault(testing::Return(true)); + ON_CALL(*filesystem_wrapper, is_valid_path(testing::_)).WillByDefault(testing::Return(true)); + + ASSERT_NO_THROW(watcher.search_for_new_files_in_directory(tmp_dir_, 0)); + + std::vector file_paths(1); + session << "SELECT path FROM files", soci::into(file_paths); + ASSERT_EQ(file_paths[0], test_file_path); +} + +TEST_F(FileWatcherTest, TestFallbackInsertion) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + const FileWatcher watcher(config, 1, &stop_file_watcher); + + const StorageDatabaseConnection connection(config); + + soci::session session = connection.get_session(); + + std::vector files(3); + + // Add some files to the vector + files.push_back({1, 1, 1}); + files.push_back({2, 2, 2}); + files.push_back({3, 3, 3}); + + // Insert the files into the database + ASSERT_NO_THROW(FileWatcher::fallback_insertion(files, 1, session)); + + // Check if the files are added to the database + int32_t file_id = 1; + int32_t sample_id = -1; + session << "SELECT sample_id FROM samples WHERE file_id = :id", soci::use(file_id), soci::into(sample_id); + ASSERT_GT(sample_id, 0); + + file_id = 2; + sample_id = -1; + session << "SELECT sample_id FROM samples WHERE file_id = :id", soci::use(file_id), soci::into(sample_id); + ASSERT_GT(sample_id, 0); + + file_id = 3; + sample_id = -1; + session << "SELECT sample_id FROM samples WHERE file_id = :id", soci::use(file_id), soci::into(sample_id); + ASSERT_GT(sample_id, 0); +} + +TEST_F(FileWatcherTest, TestHandleFilePaths) { // NOLINT(readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + // Add a file to the temporary directory + const std::string test_file_path = tmp_dir_ + "/test.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test.lbl"; + std::ofstream 
label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + const std::string test_file_path2 = tmp_dir_ + "/test2.txt"; + std::ofstream test_file2(test_file_path2); + ASSERT(test_file2.is_open(), "Could not open test file"); + test_file2 << "test"; + test_file2.close(); + ASSERT(!test_file2.is_open(), "Could not close test file"); + + const std::string label_file_path2 = tmp_dir_ + "/test2.lbl"; + std::ofstream label_file2(label_file_path2); + ASSERT(label_file2.is_open(), "Could not open label file"); + label_file2 << "2"; + label_file2.close(); + ASSERT(!label_file2.is_open(), "Could not close label file"); + + std::vector files = {test_file_path, test_file_path2}; + + const StorageDatabaseConnection connection(config); + + soci::session session = connection.get_session(); + + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get_modified_time(testing::_)).WillRepeatedly(testing::Return(1000)); + EXPECT_CALL(*filesystem_wrapper, exists(testing::_)).WillRepeatedly(testing::Return(true)); + watcher.filesystem_wrapper = filesystem_wrapper; + + const YAML::Node file_wrapper_config_node = YAML::Load(StorageTestUtils::get_dummy_file_wrapper_config_inline()); + + std::atomic exception_thrown = false; + ASSERT_NO_THROW(FileWatcher::handle_file_paths(files.begin(), files.end(), FileWrapperType::SINGLE_SAMPLE, 0, + FilesystemWrapperType::LOCAL, 1, &file_wrapper_config_node, &config, + 100, false, &exception_thrown)); + + // Check if the samples are added to the database + int32_t sample_id1 = -1; + int32_t label1 = -1; + int32_t file_id = 1; + session << "SELECT sample_id, label FROM samples WHERE file_id = :id", soci::use(file_id), soci::into(sample_id1), + soci::into(label1); + ASSERT_GT(sample_id1, 0); + ASSERT_EQ(label1, 1); + + int32_t sample_id2 = -1; + int32_t label2 = -1; + file_id = 2; + session << "SELECT sample_id, label FROM samples WHERE file_id = :id", soci::use(file_id), soci::into(sample_id2), + soci::into(label2); + ASSERT_GT(sample_id2, 0); + ASSERT_EQ(label2, 2); + + // Check if the files are added to the database + int32_t output_file_id = 0; + int32_t input_file_id = 1; + session << "SELECT file_id FROM files WHERE file_id = :id", soci::use(input_file_id), soci::into(output_file_id); + ASSERT_EQ(output_file_id, 1); + + input_file_id = 2; + session << "SELECT file_id FROM files WHERE file_id = :id", soci::use(input_file_id), soci::into(output_file_id); + ASSERT_EQ(output_file_id, 2); +} + +TEST_F(FileWatcherTest, TestConstructorWithInvalidInterval) { + std::atomic stop_file_watcher = false; + const FileWatcher watcher(YAML::LoadFile("config.yaml"), -1, &stop_file_watcher); + ASSERT_TRUE(watcher.stop_file_watcher->load()); +} + +TEST_F(FileWatcherTest, TestConstructorWithNullStopFileWatcher) { + ASSERT_THROW(const FileWatcher watcher(YAML::LoadFile("config.yaml"), 1, nullptr), modyn::utils::ModynException); +} + +TEST_F(FileWatcherTest, TestSeekWithNonExistentDirectory) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + std::filesystem::remove_all(tmp_dir_); + + watcher.seek(session); +} + +TEST_F(FileWatcherTest, TestSeekDatasetWithNonExistentDirectory) { + const YAML::Node config = 
YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + const FileWatcher watcher(config, 1, &stop_file_watcher); + std::filesystem::remove_all(tmp_dir_); +} + +TEST_F(FileWatcherTest, TestCheckFileForInsertionWithEmptyPath) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + const std::shared_ptr filesystem_wrapper = std::make_shared(); + + ASSERT_FALSE(FileWatcher::check_file_for_insertion("", false, 0, 1, filesystem_wrapper, session)); +} + +TEST_F(FileWatcherTest, TestFallbackInsertionWithEmptyVector) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const std::vector files; + + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + + ASSERT_NO_THROW(FileWatcher::fallback_insertion(files, 1, session)); +} + +TEST_F(FileWatcherTest, TestHandleFilePathsWithEmptyVector) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + std::vector files; + + const YAML::Node file_wrapper_config_node = YAML::Load(StorageTestUtils::get_dummy_file_wrapper_config_inline()); + + std::atomic exception_thrown = false; + ASSERT_NO_THROW(FileWatcher::handle_file_paths(files.begin(), files.end(), FileWrapperType::SINGLE_SAMPLE, 0, + FilesystemWrapperType::LOCAL, 1, &file_wrapper_config_node, &config, + 100, false, &exception_thrown)); +} + +TEST_F(FileWatcherTest, TestMultipleFileHandling) { // NOLINT(readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + const int16_t number_of_files = 10; + + // Add several files to the temporary directory + for (int i = 0; i < number_of_files; i++) { + const std::string test_file_path = tmp_dir_ + "/test_file" + std::to_string(i) + ".txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test_file" + std::to_string(i) + ".lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << i; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + } + + // Seek the temporary directory + ASSERT_NO_THROW(watcher.seek(session)); + + // Check if the files are added to the database + std::vector file_paths(number_of_files); + session << "SELECT path FROM files", soci::into(file_paths); + + // Make sure all files were detected and processed + for (int i = 0; i < number_of_files; i++) { + ASSERT_TRUE(std::find(file_paths.begin(), file_paths.end(), tmp_dir_ + "/test_file" + std::to_string(i) + ".txt") != + file_paths.end()); + } +} + +TEST_F(FileWatcherTest, TestDirectoryUpdateWhileRunning) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + soci::session session = connection.get_session(); + std::atomic stop_file_watcher = false; + FileWatcher watcher(config, 1, &stop_file_watcher); + + std::thread watcher_thread([&watcher, &stop_file_watcher, &session]() { + while (!stop_file_watcher) { + watcher.seek(session); + 
std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }); + + // Add a file to the temporary directory + const std::string test_file_path = tmp_dir_ + "/test_file1.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), "Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test_file1.lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + std::this_thread::sleep_for(std::chrono::seconds(2)); // wait for the watcher to process + + stop_file_watcher = true; // Need to stop the file watcher as sqlite3 can't handle multiple threads accessing the + // database at the same time + watcher_thread.join(); + + // Check if the file is added to the database + std::string file_path; + session << "SELECT path FROM files WHERE file_id=1", soci::into(file_path); + ASSERT_EQ(file_path, test_file_path); +} diff --git a/modyn/tests/storage/internal/file_watcher/file_watcher_watchdog_test.cpp b/modyn/tests/storage/internal/file_watcher/file_watcher_watchdog_test.cpp new file mode 100644 index 000000000..c32f41d12 --- /dev/null +++ b/modyn/tests/storage/internal/file_watcher/file_watcher_watchdog_test.cpp @@ -0,0 +1,282 @@ +#include "internal/file_watcher/file_watcher_watchdog.hpp" + +#include +#include + +#include + +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +class FileWatcherWatchdogTest : public ::testing::Test { + protected: + std::string tmp_dir_; + + FileWatcherWatchdogTest() + : tmp_dir_{std::filesystem::temp_directory_path().string() + "/file_watcher_watchdog_test"} {} + + void SetUp() override { + modyn::test::TestUtils::create_dummy_yaml(); + // Create temporary directory + std::filesystem::create_directory(tmp_dir_); + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + connection.create_tables(); + } + + void TearDown() override { + modyn::test::TestUtils::delete_dummy_yaml(); + if (std::filesystem::exists("test.db")) { + std::filesystem::remove("test.db"); + } + // Remove temporary directory + std::filesystem::remove_all(tmp_dir_); + } +}; + +TEST_F(FileWatcherWatchdogTest, TestConstructor) { + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + const YAML::Node config = YAML::LoadFile("config.yaml"); + ASSERT_NO_THROW(const FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown)); +} + +TEST_F(FileWatcherWatchdogTest, TestRun) { + // Collect the output of the watchdog + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + + const std::shared_ptr watchdog = + std::make_shared(config, &stop_file_watcher, &request_shutdown); + + std::thread th(&FileWatcherWatchdog::run, watchdog); + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + + stop_file_watcher = true; + th.join(); + + // Check if the watchdog has stopped + ASSERT_FALSE(th.joinable()); +} + +TEST_F(FileWatcherWatchdogTest, TestStartFileWatcherProcess) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, 
&request_shutdown); + + const StorageDatabaseConnection connection(config); + + // Add two dataset to the database + connection.add_dataset("test_dataset1", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + connection.add_dataset("test_dataset2", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + watchdog.start_file_watcher_thread(1); + std::vector file_watcher_threads; + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + ASSERT_EQ(file_watcher_threads.size(), 1); + + // Test if the file watcher process is still running + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + ASSERT_EQ(file_watcher_threads.size(), 1); + + watchdog.stop_file_watcher_thread(1); + watchdog.start_file_watcher_thread(1); + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + ASSERT_EQ(file_watcher_threads.size(), 1); + + watchdog.stop_file_watcher_thread(1); +} + +TEST_F(FileWatcherWatchdogTest, TestStopFileWatcherProcess) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + + const StorageDatabaseConnection connection(config); + + connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + watchdog.start_file_watcher_thread(1); + + std::vector file_watcher_threads; + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 1); + + watchdog.stop_file_watcher_thread(1); + + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 0); +} + +TEST_F(FileWatcherWatchdogTest, TestWatchFileWatcherThreads) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + + const StorageDatabaseConnection connection(config); + + watchdog.watch_file_watcher_threads(); + + connection.add_dataset("test_dataset1", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + watchdog.watch_file_watcher_threads(); + + std::vector file_watcher_threads; + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 1); + + watchdog.watch_file_watcher_threads(); + + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 1); + ASSERT_EQ(file_watcher_threads[0], 1); + + watchdog.stop_file_watcher_thread(1); + + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 0); + + watchdog.watch_file_watcher_threads(); + + file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + 
ASSERT_EQ(file_watcher_threads.size(), 1); + + watchdog.stop_file_watcher_thread(1); +} + +TEST_F(FileWatcherWatchdogTest, TestFileWatcherWatchdogWithNoDataset) { + // This test ensures that the watchdog handles correctly the situation where there is no dataset in the database + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + const StorageDatabaseConnection connection(config); + + watchdog.watch_file_watcher_threads(); + + // Assert that there are no running FileWatcher threads as there are no datasets + const std::vector file_watcher_threads = watchdog.get_running_file_watcher_threads(); + ASSERT_TRUE(file_watcher_threads.empty()); +} + +TEST_F(FileWatcherWatchdogTest, TestRestartFailedFileWatcherProcess) { + // This test checks that the watchdog successfully restarts a failed FileWatcher process + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + const StorageDatabaseConnection connection(config); + + connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + watchdog.start_file_watcher_thread(1); + // Simulate a failure of the FileWatcher process + watchdog.stop_file_watcher_thread(1); + + // The watchdog should detect the failure and restart the process + watchdog.watch_file_watcher_threads(); + + std::vector file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 1); + ASSERT_EQ(file_watcher_threads[0], 1); + watchdog.stop_file_watcher_thread(1); +} + +TEST_F(FileWatcherWatchdogTest, TestAddingNewDataset) { + // This test checks that the watchdog successfully starts a FileWatcher process for a new dataset + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + const StorageDatabaseConnection connection(config); + + watchdog.watch_file_watcher_threads(); + + // Add a new dataset to the database + connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + // The watchdog should start a FileWatcher process for the new dataset + watchdog.watch_file_watcher_threads(); + + std::vector file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_EQ(file_watcher_threads.size(), 1); + ASSERT_EQ(file_watcher_threads[0], 1); + watchdog.stop_file_watcher_thread(1); +} + +TEST_F(FileWatcherWatchdogTest, TestRemovingDataset) { + // This test checks that the watchdog successfully stops a FileWatcher process for a removed dataset + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + const StorageDatabaseConnection connection(config); + + // Add a new dataset to the database + 
connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", + modyn::storage::StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + watchdog.watch_file_watcher_threads(); + + // The watchdog should start a FileWatcher process for the new dataset + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Now remove the dataset from the database + connection.delete_dataset("test_dataset", 1); + + // The watchdog should stop the FileWatcher process for the removed dataset + watchdog.watch_file_watcher_threads(); + + const std::vector file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_TRUE(file_watcher_threads.empty()); +} + +TEST_F(FileWatcherWatchdogTest, TestNoDatasetsInDB) { + // This test checks that the watchdog does not start any FileWatcher threads if there are no datasets + const YAML::Node config = YAML::LoadFile("config.yaml"); + std::atomic stop_file_watcher = false; + std::atomic request_shutdown = false; + FileWatcherWatchdog watchdog(config, &stop_file_watcher, &request_shutdown); + const StorageDatabaseConnection connection(config); + + watchdog.watch_file_watcher_threads(); + + const std::vector file_watcher_threads = watchdog.get_running_file_watcher_threads(); + + ASSERT_TRUE(file_watcher_threads.empty()); +} diff --git a/modyn/tests/storage/internal/file_watcher/test_new_file_watcher.py b/modyn/tests/storage/internal/file_watcher/test_new_file_watcher.py deleted file mode 100644 index 271817850..000000000 --- a/modyn/tests/storage/internal/file_watcher/test_new_file_watcher.py +++ /dev/null @@ -1,785 +0,0 @@ -# pylint: disable=unused-argument, redefined-outer-name -import os -import pathlib -import shutil -import time -import typing -from ctypes import c_bool -from multiprocessing import Process, Value -from unittest.mock import patch - -import pytest -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.file_watcher.new_file_watcher import NewFileWatcher, run_new_file_watcher -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.abstract_filesystem_wrapper import AbstractFileSystemWrapper -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType - -FILE_TIMESTAMP = 1600000000 -TEST_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp") -TEST_FILE1 = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test1.txt") -TEST_FILE2 = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test2.txt") -TEST_FILE_WRONG_SUFFIX = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test1.csv") -TEST_DATABASE = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test.db") - - -def get_minimal_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": 0, - "database": TEST_DATABASE, - }, - "insertion_threads": 8, - } - } - - -def get_invalid_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": 0, - "database": TEST_DATABASE, - } - } - } - - -def setup(): - os.makedirs(TEST_DIR, exist_ok=True) - with open(TEST_FILE1, "w", 
encoding="utf-8") as file: - file.write("test") - with open(TEST_FILE2, "w", encoding="utf-8") as file: - file.write("test") - - -def teardown(): - shutil.rmtree(TEST_DIR) - - -@pytest.fixture(autouse=True) -def storage_database_connection(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - yield database - database.session.query(Dataset).delete() - database.session.query(File).delete() - database.session.query(Sample).delete() - database.session.commit() - - -class MockFileSystemWrapper(AbstractFileSystemWrapper): - def __init__(self): - super().__init__(TEST_DIR) - self._list = [TEST_FILE1, TEST_FILE2, TEST_FILE_WRONG_SUFFIX] - self._list_called = False - - def exists(self, path: str) -> bool: - if path == "/notexists": - return False - return True - - def isdir(self, path: str) -> bool: - if path in (TEST_FILE1, TEST_FILE2, TEST_FILE_WRONG_SUFFIX): - return False - if path == TEST_DIR: - return True - return False - - def isfile(self, path: str) -> bool: - if path in (TEST_FILE1, TEST_FILE2, TEST_FILE_WRONG_SUFFIX): - return True - return False - - def list(self, path: str, recursive: bool = False) -> list[str]: - self._list_called = True - return self._list - - def join(self, *paths: str) -> str: - return "/".join(paths) - - def get_modified(self, path: str) -> int: - return FILE_TIMESTAMP - - def get_created(self, path: str) -> int: - return FILE_TIMESTAMP - - def _get(self, path: str) -> typing.BinaryIO: - return typing.BinaryIO() - - def get_size(self, path: str) -> int: - return 2 - - def get_list_called(self) -> bool: - return self._list_called - - def delete(self, path: str) -> None: - return - - -class MockFileWrapper: - def get_number_of_samples(self) -> int: - return 2 - - def get_label(self, index: int) -> bytes: - return b"test" - - def get_all_labels(self) -> list[bytes]: - return [b"test", b"test"] - - -class MockDataset: - def __init__(self): - self.filesystem_wrapper_type = "mock" - self.base_path = TEST_DIR - - -class MockFile: - def __init__(self): - self.path = TEST_FILE1 - self.timestamp = FILE_TIMESTAMP - - -class MockQuery: - def __init__(self): - self._all = [MockFile()] - - def all(self) -> list[MockFile]: - return self._all - - -@patch.object(NewFileWatcher, "_seek_dataset", return_value=None) -def test_seek(test__seek_dataset, storage_database_connection) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test1", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - session.add( - File(dataset=dataset, path="/tmp/modyn/test", created_at=0, updated_at=FILE_TIMESTAMP + 10, number_of_samples=1) - ) - session.commit() - - new_file_watcher._seek(storage_database_connection, dataset) - assert test__seek_dataset.called - assert session.query(Dataset).first().last_timestamp == FILE_TIMESTAMP + 10 - - -@patch.object(NewFileWatcher, "_update_files_in_directory", return_value=None) -def test_seek_dataset(test__update_files_in_directory, storage_database_connection) -> None: # noqa: E501 - should_stop = Value(c_bool, False) - - session = storage_database_connection.session - - session.add( - 
Dataset( - name="test2", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - ) - session.commit() - dataset = session.query(Dataset).first() - - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - new_file_watcher._seek_dataset(session, dataset) - assert test__update_files_in_directory.called - - -def test_seek_dataset_deleted(storage_database_connection) -> None: # noqa: E501 - should_stop = Value(c_bool, False) - - session = storage_database_connection.session - - session.add( - Dataset( - name="test2", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - ) - session.commit() - - dataset = session.query(Dataset).first() - session.add( - File(dataset=dataset, path="/tmp/modyn/test", created_at=0, updated_at=FILE_TIMESTAMP + 10, number_of_samples=1) - ) - session.commit() - - process = Process(target=NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop).run) - process.start() - - start = time.time() - - time.sleep(1) - - session.delete(dataset) - session.commit() - - while time.time() - start < 5: - if not process.is_alive(): - break - time.sleep(0.1) - - assert not process.is_alive() - - -@patch.object(NewFileWatcher, "_update_files_in_directory", return_value=None) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_seek_path_not_exists( - test_get_filesystem_wrapper, test__update_files_in_directory, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test1", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path="/notexists", - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - session.add( - File(dataset=dataset, path="/tmp/modyn/test", created_at=0, updated_at=FILE_TIMESTAMP + 10, number_of_samples=1) - ) - session.commit() - - new_file_watcher._seek(storage_database_connection, dataset) - assert not test__update_files_in_directory.called - assert session.query(Dataset).first().last_timestamp == FILE_TIMESTAMP + 10 - - -@patch.object(NewFileWatcher, "_update_files_in_directory", return_value=None) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_seek_path_not_directory( - test_get_filesystem_wrapper, test__update_files_in_directory, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test1", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_FILE1, - last_timestamp=FILE_TIMESTAMP - 1, 
- file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - session.add( - File(dataset=dataset, path="/tmp/modyn/test", created_at=0, updated_at=FILE_TIMESTAMP + 10, number_of_samples=1) - ) - session.commit() - - new_file_watcher._seek(storage_database_connection, dataset) - assert not test__update_files_in_directory.called - assert session.query(Dataset).first().last_timestamp == FILE_TIMESTAMP + 10 - - -@patch.object(NewFileWatcher, "_update_files_in_directory", return_value=None) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_seek_no_datasets( - test_get_filesystem_wrapper, test__update_files_in_directory, storage_database_connection -) -> None: # noqa: E501 - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), 1, should_stop) - - new_file_watcher._seek(storage_database_connection, None) - assert not test__update_files_in_directory.called - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher.get_file_wrapper", return_value=MockFileWrapper()) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_update_files_in_directory( - test_get_file_wrapper, test_get_filesystem_wrapper, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test5", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - new_file_watcher._update_files_in_directory( - filesystem_wrapper=MockFileSystemWrapper(), - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - path=TEST_DIR, - timestamp=FILE_TIMESTAMP - 1, - session=session, - dataset=dataset, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - new_file_watcher._update_files_in_directory( - filesystem_wrapper=MockFileSystemWrapper(), - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - path=TEST_DIR, - timestamp=FILE_TIMESTAMP - 1, - session=session, - dataset=dataset, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher.get_file_wrapper", return_value=MockFileWrapper()) -@patch( - 
"modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_update_files_in_directory_mt_disabled( - test_get_file_wrapper, test_get_filesystem_wrapper, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test5", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - new_file_watcher._disable_mt = True - - new_file_watcher._update_files_in_directory( - filesystem_wrapper=MockFileSystemWrapper(), - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - path=TEST_DIR, - timestamp=FILE_TIMESTAMP - 1, - session=session, - dataset=dataset, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - new_file_watcher._update_files_in_directory( - filesystem_wrapper=MockFileSystemWrapper(), - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - path=TEST_DIR, - timestamp=FILE_TIMESTAMP - 1, - session=session, - dataset=dataset, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher.get_file_wrapper", return_value=MockFileWrapper()) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_handle_file_paths_presupplied_config( - test_get_file_wrapper, test_get_filesystem_wrapper, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test_handle_file_paths", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - session.add(dataset) - session.commit() - - file_paths = MockFileSystemWrapper().list(TEST_DIR, recursive=True) - new_file_watcher._handle_file_paths( - -1, - 1234, - False, - False, - file_paths, - get_minimal_modyn_config(), - ".txt", - MockFileSystemWrapper(), - "fw", - FILE_TIMESTAMP - 1, - "test_handle_file_paths", - 1, - session, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert 
result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - new_file_watcher._handle_file_paths( - -1, - 1234, - False, - False, - file_paths, - get_minimal_modyn_config(), - ".txt", - MockFileSystemWrapper(), - "fw", - FILE_TIMESTAMP - 1, - "test_handle_file_paths", - 1, - session, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher.get_file_wrapper", return_value=MockFileWrapper()) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_handle_file_paths_no_presupplied_config( - test_get_file_wrapper, test_get_filesystem_wrapper, storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test_handle_file_paths", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - file_watcher_interval=0.1, - ) - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - session.add(dataset) - session.commit() - - file_paths = MockFileSystemWrapper().list(TEST_DIR, recursive=True) - new_file_watcher._handle_file_paths( - -1, - 1234, - False, - False, - file_paths, - get_minimal_modyn_config(), - ".txt", - MockFileSystemWrapper(), - "fw", - FILE_TIMESTAMP - 1, - "test_handle_file_paths", - 1, - None, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - new_file_watcher._handle_file_paths( - -1, - 1234, - False, - False, - file_paths, - get_minimal_modyn_config(), - ".txt", - MockFileSystemWrapper(), - "fw", - FILE_TIMESTAMP - 1, - "test_handle_file_paths", - 1, - None, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher.get_file_wrapper", return_value=MockFileWrapper()) -@patch( - "modyn.storage.internal.file_watcher.new_file_watcher.get_filesystem_wrapper", return_value=MockFileSystemWrapper() -) -def test_update_files_in_directory_ignore_last_timestamp( - test_get_file_wrapper, test_get_filesystem_wrapper, 
storage_database_connection -) -> None: # noqa: E501 - session = storage_database_connection.session - dataset = Dataset( - name="test6", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=FILE_TIMESTAMP - 1, - ignore_last_timestamp=True, - file_watcher_interval=0.1, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - - new_file_watcher._update_files_in_directory( - filesystem_wrapper=MockFileSystemWrapper(), - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - path=TEST_DIR, - timestamp=FILE_TIMESTAMP + 10, - session=session, - dataset=dataset, - ) - - result = session.query(File).all() - assert result is not None - assert len(result) == 2 - assert result[0].path == TEST_FILE1 - assert result[0].created_at == FILE_TIMESTAMP - assert result[0].number_of_samples == 2 - assert result[0].dataset_id == 1 - - result = session.query(Sample).all() - assert result is not None - assert len(result) == 4 - assert result[0].file_id == 1 - - -def test_update_files_in_directory_not_exists(storage_database_connection) -> None: - session = storage_database_connection.session - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), 1, should_stop) - mock_file_system_wrapper = MockFileSystemWrapper() - new_file_watcher._update_files_in_directory( - filesystem_wrapper=mock_file_system_wrapper, - file_wrapper_type=MockFileWrapper(), - path="/notexists", - timestamp=FILE_TIMESTAMP - 1, - session=session, - dataset=MockDataset(), - ) - assert not mock_file_system_wrapper.get_list_called() - - -@patch.object(NewFileWatcher, "_seek", return_value=None) -def test_run(mock_seek, storage_database_connection) -> None: - session = storage_database_connection.session - dataset = Dataset( - name="test7", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - file_wrapper_config='{"file_extension": ".txt"}', - last_timestamp=-1, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), dataset.dataset_id, should_stop) - watcher_process = Process(target=new_file_watcher.run, args=()) - watcher_process.start() - should_stop.value = True # type: ignore - watcher_process.join() - #  If we get here, the process has stopped - - -def test_get_datasets(storage_database_connection): - session = storage_database_connection.session - should_stop = Value(c_bool, False) - new_file_watcher = NewFileWatcher(get_minimal_modyn_config(), 1, should_stop) - datasets = new_file_watcher._get_datasets(session) - assert len(datasets) == 0 - - dataset = Dataset( - name="test_get_datasets", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - last_timestamp=FILE_TIMESTAMP - 1, - file_wrapper_config='{"file_extension": ".txt"}', - file_watcher_interval=0.1, - ignore_last_timestamp=True, - ) - session.add(dataset) - session.commit() - - datasets: list[Dataset] = new_file_watcher._get_datasets(session) - 
assert len(datasets) == 1 - assert datasets[0].name == "test_get_datasets" - - -def test_run_new_file_watcher(storage_database_connection): - session = storage_database_connection.session - should_stop = Value(c_bool, False) - - dataset = Dataset( - name="test8", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path=TEST_DIR, - last_timestamp=FILE_TIMESTAMP - 1, - file_wrapper_config='{"file_extension": ".txt"}', - file_watcher_interval=0.1, - ignore_last_timestamp=True, - ) - session.add(dataset) - session.commit() - - Process(target=run_new_file_watcher, args=(get_minimal_modyn_config(), dataset.dataset_id, should_stop)).start() - - time.sleep(2) # If this test fails, try increasing this number - should_stop.value = True # type: ignore - - result = session.query(File).filter(File.path == TEST_FILE1).all() - assert result is not None - assert len(result) == 1 - assert result[0].path == TEST_FILE1 - assert result[0].number_of_samples == 1 - assert result[0].dataset_id == 1 diff --git a/modyn/tests/storage/internal/file_watcher/test_new_file_watcher_watch_dog.py b/modyn/tests/storage/internal/file_watcher/test_new_file_watcher_watch_dog.py deleted file mode 100644 index 3817de1a1..000000000 --- a/modyn/tests/storage/internal/file_watcher/test_new_file_watcher_watch_dog.py +++ /dev/null @@ -1,182 +0,0 @@ -# pylint: disable=unused-argument, redefined-outer-name -import os -import pathlib -import shutil -import typing -from ctypes import c_bool -from multiprocessing import Process, Value -from unittest.mock import patch - -import pytest -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.file_watcher.new_file_watcher_watch_dog import NewFileWatcherWatchDog -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType - -TEST_DATABASE = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test.db") -TEST_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp") -TEST_FILE1 = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "test1.txt") - - -def get_minimal_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": 0, - "database": TEST_DATABASE, - }, - } - } - - -def get_invalid_modyn_config() -> dict: - return { - "storage": { - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": 0, - "database": TEST_DATABASE, - }, - } - } - - -def setup(): - os.makedirs(TEST_DIR, exist_ok=True) - with open(TEST_FILE1, "w", encoding="utf-8") as file: - file.write("test") - - -def teardown(): - shutil.rmtree(TEST_DIR) - - -@pytest.fixture(autouse=True) -def session(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - yield database.session - database.session.query(Dataset).delete() - database.session.query(File).delete() - database.session.query(Sample).delete() - database.session.commit() - - -class MockProcess(Process): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._running = Value(c_bool, True) - - def is_alive(self): - return self._running.value - - def 
terminate(self): - self._running.value = False - - def join(self, timeout: typing.Optional[float] = ...) -> None: - pass - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher_watch_dog.Process", return_value=MockProcess()) -def test_start_file_watcher(mock_process, session): - should_stop = Value(c_bool, False) - new_file_watcher_watch_dog = NewFileWatcherWatchDog(get_minimal_modyn_config(), should_stop) - new_file_watcher_watch_dog._start_file_watcher_process(1) - - assert new_file_watcher_watch_dog._file_watcher_processes[1][0] is not None - - -def test_stop_file_watcher_process(session): - should_stop = Value(c_bool, False) - new_file_watcher_watch_dog = NewFileWatcherWatchDog(get_minimal_modyn_config(), should_stop) - - mock_process = MockProcess() - - should_stop = Value(c_bool, False) - - new_file_watcher_watch_dog._file_watcher_processes[1] = (mock_process, should_stop, 0) - - new_file_watcher_watch_dog._stop_file_watcher_process(1) - - assert not mock_process.is_alive() - assert should_stop.value - - -def test_watch_file_watcher_processes_dataset_not_in_database(session): - should_stop = Value(c_bool, False) - new_file_watcher_watch_dog = NewFileWatcherWatchDog(get_minimal_modyn_config(), should_stop) - - mock_process = MockProcess() - - should_stop = Value(c_bool, False) - - new_file_watcher_watch_dog._file_watcher_processes[1] = (mock_process, should_stop, 0) - - new_file_watcher_watch_dog._watch_file_watcher_processes() - - assert not mock_process.is_alive() - assert should_stop.value - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher_watch_dog.Process", return_value=MockProcess()) -def test_watch_file_watcher_processes_dataset_not_in_dataset_ids_in_file_watcher_processes(mock_process, session): - dataset = Dataset( - name="test1", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path="/notexists", - file_watcher_interval=0.1, - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - - new_file_watcher_watch_dog = NewFileWatcherWatchDog(get_minimal_modyn_config(), should_stop) - - new_file_watcher_watch_dog._watch_file_watcher_processes() - - assert dataset.dataset_id in new_file_watcher_watch_dog._file_watcher_processes - assert new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][0] is not None - assert new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][0].is_alive() - assert not new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][1].value - - -@patch("modyn.storage.internal.file_watcher.new_file_watcher_watch_dog.Process", return_value=MockProcess()) -def test_watch_file_watcher_processes_dataset_in_dataset_ids_in_file_watcher_processes_not_alive(mock_process, session): - dataset = Dataset( - name="test1", - description="test description", - filesystem_wrapper_type=FilesystemWrapperType.LocalFilesystemWrapper, - file_wrapper_type=FileWrapperType.SingleSampleFileWrapper, - base_path="/notexists", - file_watcher_interval=0.1, - last_timestamp=0, - ) - session.add(dataset) - session.commit() - - should_stop = Value(c_bool, False) - - new_file_watcher_watch_dog = NewFileWatcherWatchDog(get_minimal_modyn_config(), should_stop) - - new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id] = (mock_process, should_stop, 0) - - mock_process.is_alive.return_value = False - - 
new_file_watcher_watch_dog._watch_file_watcher_processes() - - assert dataset.dataset_id in new_file_watcher_watch_dog._file_watcher_processes - assert new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][0] is not None - assert new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][0].is_alive() - assert not new_file_watcher_watch_dog._file_watcher_processes[dataset.dataset_id][1].value diff --git a/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp new file mode 100644 index 000000000..0dea1a09d --- /dev/null +++ b/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp @@ -0,0 +1,235 @@ +#include "internal/file_wrapper/binary_file_wrapper.hpp" + +#include +#include +#include +#include + +#include + +#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +class BinaryFileWrapperTest : public ::testing::Test { + protected: + std::string file_name_; + YAML::Node config_; + std::shared_ptr filesystem_wrapper_; + std::string tmp_dir_ = std::filesystem::temp_directory_path().string() + "/binary_file_wrapper_test"; + + BinaryFileWrapperTest() + : config_{StorageTestUtils::get_dummy_file_wrapper_config()}, + filesystem_wrapper_{std::make_shared()} { + file_name_ = tmp_dir_ + "/test.bin"; + } + + void SetUp() override { + std::filesystem::create_directory(tmp_dir_); + + std::ofstream file(file_name_, std::ios::binary); + const std::vector> data = {{42, 12}, {43, 13}, {44, 14}, {45, 15}}; + for (const auto& [payload, label] : data) { + payload_to_file(file, payload, label); + } + file.close(); + } + + static void payload_to_file(std::ofstream& file, uint16_t payload, uint16_t label) { + file.write(reinterpret_cast(&payload), sizeof(uint16_t)); + file.write(reinterpret_cast(&label), sizeof(uint16_t)); + } + + void TearDown() override { std::filesystem::remove_all(file_name_); } +}; + +TEST_F(BinaryFileWrapperTest, TestGetNumberOfSamples) { + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + ASSERT_EQ(file_wrapper.get_number_of_samples(), 4); + + stream_ptr->close(); +} + +TEST_F(BinaryFileWrapperTest, TestValidateFileExtension) { + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + ASSERT_NO_THROW(const BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_);); +} + +TEST_F(BinaryFileWrapperTest, TestValidateRequestIndices) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillRepeatedly(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + std::vector sample = file_wrapper.get_sample(0); + + 
ASSERT_EQ(sample.size(), 2); + ASSERT_EQ((sample)[0], 12); + + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + BinaryFileWrapper file_wrapper2(file_name_, config_, filesystem_wrapper_); + ASSERT_THROW(file_wrapper2.get_sample(8), modyn::utils::ModynException); +} + +TEST_F(BinaryFileWrapperTest, TestGetLabel) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + ASSERT_EQ(file_wrapper.get_label(0), 42); + ASSERT_EQ(file_wrapper.get_label(1), 43); + ASSERT_EQ(file_wrapper.get_label(2), 44); + ASSERT_EQ(file_wrapper.get_label(3), 45); +} + +TEST_F(BinaryFileWrapperTest, TestGetAllLabels) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + std::vector labels = file_wrapper.get_all_labels(); + ASSERT_EQ(labels.size(), 4); + ASSERT_EQ((labels)[0], 42); + ASSERT_EQ((labels)[1], 43); + ASSERT_EQ((labels)[2], 44); + ASSERT_EQ((labels)[3], 45); +} + +TEST_F(BinaryFileWrapperTest, TestGetSample) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillRepeatedly(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillRepeatedly(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + std::vector sample = file_wrapper.get_sample(0); + ASSERT_EQ(sample.size(), 2); + ASSERT_EQ((sample)[0], 12); + + sample = file_wrapper.get_sample(1); + ASSERT_EQ(sample.size(), 2); + ASSERT_EQ((sample)[0], 13); + + sample = file_wrapper.get_sample(2); + ASSERT_EQ(sample.size(), 2); + ASSERT_EQ((sample)[0], 14); + + sample = file_wrapper.get_sample(3); + ASSERT_EQ(sample.size(), 2); + ASSERT_EQ((sample)[0], 15); +} + +TEST_F(BinaryFileWrapperTest, TestGetSamples) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillRepeatedly(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillRepeatedly(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + std::vector> samples = file_wrapper.get_samples(0, 3); + ASSERT_EQ(samples.size(), 4); + ASSERT_EQ((samples)[0][0], 12); + ASSERT_EQ((samples)[1][0], 13); + ASSERT_EQ((samples)[2][0], 14); + ASSERT_EQ((samples)[3][0], 15); + + samples = file_wrapper.get_samples(1, 3); + ASSERT_EQ(samples.size(), 3); + ASSERT_EQ((samples)[0][0], 13); + ASSERT_EQ((samples)[1][0], 14); + ASSERT_EQ((samples)[2][0], 15); + + samples = file_wrapper.get_samples(2, 3); + ASSERT_EQ(samples.size(), 2); + ASSERT_EQ((samples)[0][0], 14); + ASSERT_EQ((samples)[1][0], 15); + + samples = file_wrapper.get_samples(3, 3); + ASSERT_EQ(samples.size(), 1); + ASSERT_EQ((samples)[0][0], 15); + + 
ASSERT_THROW(file_wrapper.get_samples(4, 3), modyn::utils::ModynException); + + samples = file_wrapper.get_samples(1, 2); + ASSERT_EQ(samples.size(), 2); + ASSERT_EQ((samples)[0][0], 13); + ASSERT_EQ((samples)[1][0], 14); +} + +TEST_F(BinaryFileWrapperTest, TestGetSamplesFromIndices) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillRepeatedly(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillRepeatedly(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + std::vector label_indices{0, 1, 2, 3}; + std::vector> samples = file_wrapper.get_samples_from_indices(label_indices); + ASSERT_EQ(samples.size(), 4); + ASSERT_EQ((samples)[0][0], 12); + ASSERT_EQ((samples)[1][0], 13); + ASSERT_EQ((samples)[2][0], 14); + ASSERT_EQ((samples)[3][0], 15); + + label_indices = {1, 2, 3}; + samples = file_wrapper.get_samples_from_indices(label_indices); + ASSERT_EQ(samples.size(), 3); + ASSERT_EQ((samples)[0][0], 13); + ASSERT_EQ((samples)[1][0], 14); + ASSERT_EQ((samples)[2][0], 15); + + label_indices = {2}; + samples = file_wrapper.get_samples_from_indices(label_indices); + ASSERT_EQ(samples.size(), 1); + ASSERT_EQ((samples)[0][0], 14); + + label_indices = {1, 3}; + samples = file_wrapper.get_samples_from_indices(label_indices); + ASSERT_EQ(samples.size(), 2); + ASSERT_EQ((samples)[0][0], 13); + ASSERT_EQ((samples)[1][0], 15); + + label_indices = {3, 1, 3}; + samples = file_wrapper.get_samples_from_indices(label_indices); + ASSERT_EQ(samples.size(), 3); + ASSERT_EQ((samples)[0][0], 15); + ASSERT_EQ((samples)[1][0], 13); + ASSERT_EQ((samples)[2][0], 15); +} + +TEST_F(BinaryFileWrapperTest, TestDeleteSamples) { + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + + const std::vector label_indices{0, 1, 2, 3}; + + ASSERT_NO_THROW(file_wrapper.delete_samples(label_indices)); +} \ No newline at end of file diff --git a/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp new file mode 100644 index 000000000..a500c3e83 --- /dev/null +++ b/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp @@ -0,0 +1,169 @@ +#include "internal/file_wrapper/csv_file_wrapper.hpp" + +#include +#include + +#include +#include + +#include "gmock/gmock.h" +#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp" +#include "modyn/utils/utils.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +class CsvFileWrapperTest : public ::testing::Test { + protected: + std::string file_name_; + YAML::Node config_; + std::shared_ptr filesystem_wrapper_; + std::string tmp_dir_ = std::filesystem::temp_directory_path().string() + "/csv_file_wrapper_test"; + + CsvFileWrapperTest() + : config_{StorageTestUtils::get_dummy_file_wrapper_config()}, + filesystem_wrapper_{std::make_shared()} { + file_name_ = tmp_dir_ + "/test.csv"; + } + + void SetUp() override { + std::filesystem::create_directory(tmp_dir_); + + std::ofstream file(file_name_); + file << 
"id,first_name,last_name,age\n"; + file << "1,John,Doe,25\n"; + file << "2,Jane,Smith,30\n"; + file << "3,Michael,Johnson,35\n"; + file.close(); + } + + void TearDown() override { std::filesystem::remove_all(file_name_); } +}; + +TEST_F(CsvFileWrapperTest, TestGetNumberOfSamples) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const uint64_t expected_number_of_samples = 3; + const uint64_t actual_number_of_samples = file_wrapper.get_number_of_samples(); + + ASSERT_EQ(actual_number_of_samples, expected_number_of_samples); +} + +TEST_F(CsvFileWrapperTest, TestGetLabel) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const int64_t index = 1; + const int64_t expected_label = 2; + const int64_t actual_label = file_wrapper.get_label(index); + + ASSERT_EQ(actual_label, expected_label); + + const int64_t invalid_index = 3; + ASSERT_THROW(file_wrapper.get_label(invalid_index), modyn::utils::ModynException); + + const int64_t negative_index = -1; + ASSERT_THROW(file_wrapper.get_label(negative_index), modyn::utils::ModynException); +} + +TEST_F(CsvFileWrapperTest, TestGetAllLabels) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const std::vector expected_labels = {1, 2, 3}; + const std::vector actual_labels = file_wrapper.get_all_labels(); + + ASSERT_EQ(actual_labels, expected_labels); +} + +TEST_F(CsvFileWrapperTest, TestGetSamples) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const int64_t start = 1; + const int64_t end = 3; + const std::vector> expected_samples = { + {'J', 'a', 'n', 'e', ',', 'S', 'm', 'i', 't', 'h', ',', '3', '0'}, + {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n', ',', '3', '5'}, + }; + const std::vector> actual_samples = file_wrapper.get_samples(start, end); + + ASSERT_EQ(actual_samples, expected_samples); +} + +TEST_F(CsvFileWrapperTest, TestGetSample) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const int64_t index = 1; + const 
std::vector expected_sample = {'J', 'a', 'n', 'e', ',', 'S', 'm', 'i', 't', 'h', ',', '3', '0'}; + const std::vector actual_sample = file_wrapper.get_sample(index); + + ASSERT_EQ(actual_sample, expected_sample); +} + +TEST_F(CsvFileWrapperTest, TestGetSamplesFromIndices) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const std::vector indices = {0, 2}; + const std::vector> expected_samples = { + {'J', 'o', 'h', 'n', ',', 'D', 'o', 'e', ',', '2', '5'}, + {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n', ',', '3', '5'}, + }; + const std::vector> actual_samples = file_wrapper.get_samples_from_indices(indices); + + ASSERT_EQ(actual_samples, expected_samples); +} + +TEST_F(CsvFileWrapperTest, TestDeleteSamples) { + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + const std::vector indices = {0, 1}; + + file_wrapper.delete_samples(indices); + + const std::vector> expected_samples = { + {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n', ',', '3', '5'}, + }; + + std::ifstream file2(file_name_, std::ios::binary); + file2.ignore(std::numeric_limits::max(), '\n'); + file2.ignore(2); + std::vector buffer(std::istreambuf_iterator(file2), {}); + file2.close(); + buffer.pop_back(); + + ASSERT_EQ(buffer, expected_samples[0]); +} diff --git a/modyn/tests/storage/internal/file_wrapper/file_wrapper_utils_test.cpp b/modyn/tests/storage/internal/file_wrapper/file_wrapper_utils_test.cpp new file mode 100644 index 000000000..74f5a9395 --- /dev/null +++ b/modyn/tests/storage/internal/file_wrapper/file_wrapper_utils_test.cpp @@ -0,0 +1,40 @@ +#include "internal/file_wrapper/file_wrapper_utils.hpp" + +#include + +#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +TEST(UtilsTest, TestGetFileWrapper) { + YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); // NOLINT + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get_file_size(testing::_)).WillOnce(testing::Return(8)); + EXPECT_CALL(*filesystem_wrapper, exists(testing::_)).WillRepeatedly(testing::Return(true)); + std::unique_ptr file_wrapper1 = + get_file_wrapper("Testpath.txt", FileWrapperType::SINGLE_SAMPLE, config, filesystem_wrapper); + ASSERT_NE(file_wrapper1, nullptr); + ASSERT_EQ(file_wrapper1->get_type(), FileWrapperType::SINGLE_SAMPLE); + + const std::shared_ptr binary_stream_ptr = std::make_shared(); + binary_stream_ptr->open("Testpath.bin", std::ios::binary); + + EXPECT_CALL(*filesystem_wrapper, get_stream(testing::_)).WillOnce(testing::Return(binary_stream_ptr)); + config["file_extension"] = ".bin"; + std::unique_ptr file_wrapper2 = + get_file_wrapper("Testpath.bin", FileWrapperType::BINARY, config, filesystem_wrapper); + ASSERT_NE(file_wrapper2, nullptr); + 
ASSERT_EQ(file_wrapper2->get_type(), FileWrapperType::BINARY);
+
+  const std::shared_ptr<std::ifstream> csv_stream_ptr = std::make_shared<std::ifstream>();
+  csv_stream_ptr->open("Testpath.csv", std::ios::binary);
+
+  EXPECT_CALL(*filesystem_wrapper, get_stream(testing::_)).WillOnce(testing::Return(csv_stream_ptr));
+  config["file_extension"] = ".csv";
+  std::unique_ptr<FileWrapper> file_wrapper3 =
+      get_file_wrapper("Testpath.csv", FileWrapperType::CSV, config, filesystem_wrapper);
+  ASSERT_NE(file_wrapper3, nullptr);
+  ASSERT_EQ(file_wrapper3->get_type(), FileWrapperType::CSV);
+}
diff --git a/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp b/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp
new file mode 100644
index 000000000..0b58a07fb
--- /dev/null
+++ b/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp
@@ -0,0 +1,31 @@
+#pragma once
+
+#include
+#include
+
+#include
+
+#include "gmock/gmock.h"
+#include "internal/file_wrapper/file_wrapper.hpp"
+#include "storage_test_utils.hpp"
+
+namespace modyn::storage {
+class MockFileWrapper : public FileWrapper {
+ public:
+  MockFileWrapper(const std::string& path, const YAML::Node& fw_config, std::shared_ptr<FilesystemWrapper>& fs_wrapper)
+      : FileWrapper(path, fw_config, fs_wrapper) {}
+  MOCK_METHOD(int64_t, get_number_of_samples, (), (override));
+  MOCK_METHOD(std::vector<std::vector<unsigned char>>*, get_samples, (int64_t start, int64_t end), (override));
+  MOCK_METHOD(int64_t, get_label, (int64_t index), (override));
+  MOCK_METHOD(std::vector<int64_t>*, get_all_labels, (), (override));
+  MOCK_METHOD(std::vector<unsigned char>*, get_sample, (int64_t index), (override));
+  MOCK_METHOD(std::vector<std::vector<unsigned char>>*, get_samples_from_indices, (std::vector<int64_t>* indices),
+              (override));
+  MOCK_METHOD(FileWrapperType, get_type, (), (override));
+  MOCK_METHOD(void, validate_file_extension, (), (override));
+  MOCK_METHOD(void, delete_samples, (std::vector<int64_t>* indices), (override));
+  MOCK_METHOD(void, set_file_path, (const std::string& path), (override));
+  ~MockFileWrapper() override = default;
+  MockFileWrapper(const MockFileWrapper& other) : FileWrapper(other) {}
+};
+}  // namespace modyn::storage
diff --git a/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp
new file mode 100644
index 000000000..ac3a1bd9a
--- /dev/null
+++ b/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp
@@ -0,0 +1,113 @@
+#include "internal/file_wrapper/single_sample_file_wrapper.hpp"
+
+#include
+
+#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp"
+#include "storage_test_utils.hpp"
+#include "test_utils.hpp"
+
+using namespace modyn::storage;
+
+TEST(SingleSampleFileWrapperTest, TestGetNumberOfSamples) {
+  const std::string file_name = "test.txt";
+  const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config();
+  const std::shared_ptr<MockFilesystemWrapper> filesystem_wrapper = std::make_shared<MockFilesystemWrapper>();
+  ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper);
+  ASSERT_EQ(file_wrapper.get_number_of_samples(), 1);
+}
+
+TEST(SingleSampleFileWrapperTest, TestGetLabel) {
+  const std::string file_name = "test.txt";
+  const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config();
+  const std::shared_ptr<MockFilesystemWrapper> filesystem_wrapper = std::make_shared<MockFilesystemWrapper>();
+  const std::vector<unsigned char> bytes = {'1', '2', '3', '4', '5', '6', '7', '8'};
+  EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes));
+  EXPECT_CALL(*filesystem_wrapper,
exists(testing::_)).WillOnce(testing::Return(true)); + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + ASSERT_EQ(file_wrapper.get_label(0), 12345678); +} + +TEST(SingleSampleFileWrapperTest, TestGetAllLabels) { + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + const std::vector bytes = {'1', '2', '3', '4', '5', '6', '7', '8'}; + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); + EXPECT_CALL(*filesystem_wrapper, exists(testing::_)).WillOnce(testing::Return(true)); + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + const std::vector labels = file_wrapper.get_all_labels(); + ASSERT_EQ(labels.size(), 1); + ASSERT_EQ((labels)[0], 12345678); +} + +TEST(SingleSampleFileWrapperTest, TestGetSamples) { + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + const std::vector bytes = {'1', '2', '3', '4', '5', '6', '7', '8'}; + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + const std::vector> samples = file_wrapper.get_samples(0, 1); + ASSERT_EQ(samples.size(), 1); + ASSERT_EQ(samples[0].size(), 8); + ASSERT_EQ((samples)[0][0], '1'); + ASSERT_EQ((samples)[0][1], '2'); + ASSERT_EQ((samples)[0][2], '3'); + ASSERT_EQ((samples)[0][3], '4'); + ASSERT_EQ((samples)[0][4], '5'); + ASSERT_EQ((samples)[0][5], '6'); + ASSERT_EQ((samples)[0][6], '7'); + ASSERT_EQ((samples)[0][7], '8'); +} + +TEST(SingleSampleFileWrapperTest, TestGetSample) { + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + const std::vector bytes = {'1', '2', '3', '4', '5', '6', '7', '8'}; + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + const std::vector samples = file_wrapper.get_sample(0); + ASSERT_EQ(samples.size(), 8); + ASSERT_EQ((samples)[0], '1'); + ASSERT_EQ((samples)[1], '2'); + ASSERT_EQ((samples)[2], '3'); + ASSERT_EQ((samples)[3], '4'); + ASSERT_EQ((samples)[4], '5'); + ASSERT_EQ((samples)[5], '6'); + ASSERT_EQ((samples)[6], '7'); + ASSERT_EQ((samples)[7], '8'); +} + +TEST(SingleSampleFileWrapperTest, TestGetSamplesFromIndices) { + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + const std::vector bytes = {'1', '2', '3', '4', '5', '6', '7', '8'}; + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + const std::vector indices = {0}; + const std::vector> samples = file_wrapper.get_samples_from_indices(indices); + ASSERT_EQ(samples.size(), 1); + ASSERT_EQ(samples[0].size(), 8); + ASSERT_EQ((samples)[0][0], '1'); + ASSERT_EQ((samples)[0][1], '2'); + ASSERT_EQ((samples)[0][2], '3'); + ASSERT_EQ((samples)[0][3], 
'4'); + ASSERT_EQ((samples)[0][4], '5'); + ASSERT_EQ((samples)[0][5], '6'); + ASSERT_EQ((samples)[0][6], '7'); + ASSERT_EQ((samples)[0][7], '8'); +} + +TEST(SingleSampleFileWrapperTest, TestDeleteSamples) { + const std::shared_ptr filesystem_wrapper = std::make_shared(); + + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + + const std::vector indices = {0}; + file_wrapper.delete_samples(indices); +} diff --git a/modyn/tests/storage/internal/file_wrapper/test_binary_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_binary_file_wrapper.py deleted file mode 100644 index dc12acbcf..000000000 --- a/modyn/tests/storage/internal/file_wrapper/test_binary_file_wrapper.py +++ /dev/null @@ -1,155 +0,0 @@ -import os -import pathlib -import shutil - -import pytest -from modyn.storage.internal.file_wrapper.binary_file_wrapper import BinaryFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType - -TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") -FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.bin") -FILE_DATA = b"\x00\x01\x00\x02\x00\x01\x00\x0f\x00\x00\x07\xd0" # [1,2,1,15,0,2000] -INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") -FILE_WRAPPER_CONFIG = { - "record_size": 4, - "label_size": 2, - "byteorder": "big", -} -SMALL_RECORD_SIZE_CONFIG = { - "record_size": 2, - "label_size": 2, - "byteorder": "big", -} -INDIVISIBLE_RECORD_SIZE_CONFIG = { - "record_size": 5, - "label_size": 2, - "byteorder": "big", -} - - -def setup(): - os.makedirs(TMP_DIR, exist_ok=True) - - with open(FILE_PATH, "wb") as file: - file.write(FILE_DATA) - - -def teardown(): - os.remove(FILE_PATH) - shutil.rmtree(TMP_DIR) - - -class MockFileSystemWrapper: - def __init__(self, file_path): - self.file_path = file_path - - def get(self, file_path): - with open(file_path, "rb") as file: - return file.read() - - def get_size(self, path): - return os.path.getsize(path) - - -def test_init(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.file_path == FILE_PATH - assert file_wrapper.file_wrapper_type == FileWrapperType.BinaryFileWrapper - - -def test_init_with_small_record_size_config(): - with pytest.raises(ValueError): - BinaryFileWrapper(FILE_PATH, SMALL_RECORD_SIZE_CONFIG, MockFileSystemWrapper(FILE_PATH)) - - -def test_init_with_invalid_file_extension(): - with pytest.raises(ValueError): - BinaryFileWrapper( - INVALID_FILE_EXTENSION_PATH, - FILE_WRAPPER_CONFIG, - MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH), - ) - - -def test_init_with_indivisiable_record_size(): - with pytest.raises(ValueError): - BinaryFileWrapper( - FILE_PATH, - INDIVISIBLE_RECORD_SIZE_CONFIG, - MockFileSystemWrapper(FILE_PATH), - ) - - -def test_get_number_of_samples(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_number_of_samples() == 3 - - -def test_get_sample(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - sample = file_wrapper.get_sample(0) - assert sample == b"\x00\x02" - - sample = file_wrapper.get_sample(2) - assert sample == b"\x07\xd0" - - -def 
test_get_sample_with_invalid_index(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_sample(10) - - -def test_get_label(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - label = file_wrapper.get_label(0) - assert label == 1 - - label = file_wrapper.get_label(2) - assert label == 0 - - -def test_get_all_labels(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_all_labels() == [1, 1, 0] - - -def test_get_label_with_invalid_index(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_label(10) - - -def test_get_samples(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - samples = file_wrapper.get_samples(0, 1) - assert len(samples) == 1 - assert samples[0] == b"\x00\x02" - - samples = file_wrapper.get_samples(0, 2) - assert len(samples) == 2 - assert samples[0] == b"\x00\x02" - assert samples[1] == b"\x00\x0f" - - -def test_get_samples_with_invalid_index(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples(0, 5) - - with pytest.raises(IndexError): - file_wrapper.get_samples(3, 4) - - -def test_get_samples_from_indices(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - samples = file_wrapper.get_samples_from_indices([0, 2]) - assert len(samples) == 2 - assert samples[0] == b"\x00\x02" - assert samples[1] == b"\x07\xd0" - - -def test_get_samples_from_indices_with_invalid_indices(): - file_wrapper = BinaryFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples_from_indices([-2, 1]) diff --git a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py deleted file mode 100644 index b345b574e..000000000 --- a/modyn/tests/storage/internal/file_wrapper/test_csv_file_wrapper.py +++ /dev/null @@ -1,278 +0,0 @@ -import os -import pathlib -import shutil - -import pytest -from modyn.storage.internal.file_wrapper.csv_file_wrapper import CsvFileWrapper -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType - -TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") -FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.csv") -CUSTOM_FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "wrong_test.csv") -FILE_DATA = b"a;b;c;d;12\ne;f;g;h;76" -INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") -FILE_WRAPPER_CONFIG = { - "ignore_first_line": False, - "label_index": 4, - "separator": ";", -} - - -def setup(): - os.makedirs(TMP_DIR, exist_ok=True) - - with open(FILE_PATH, "wb") as file: - file.write(FILE_DATA) - - -def teardown(): - os.remove(FILE_PATH) - shutil.rmtree(TMP_DIR) - - -class MockFileSystemWrapper: - def __init__(self, file_path): - self.file_path = file_path - - def get(self, file_path): - with open(file_path, "rb") as file: - return file.read() - - def get_size(self, path): - return 
os.path.getsize(path) - - -def test_init(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.file_path == FILE_PATH - assert file_wrapper.file_wrapper_type == FileWrapperType.CsvFileWrapper - assert file_wrapper.encoding == "utf-8" - assert file_wrapper.label_index == 4 - assert not file_wrapper.ignore_first_line - assert file_wrapper.separator == ";" - - -def test_init_with_invalid_file_extension(): - with pytest.raises(ValueError): - CsvFileWrapper( - INVALID_FILE_EXTENSION_PATH, - FILE_WRAPPER_CONFIG, - MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH), - ) - - -def test_get_number_of_samples(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_number_of_samples() == 2 - - # check if the first line is correctly ignored - file_wrapper.ignore_first_line = True - assert file_wrapper.get_number_of_samples() == 1 - - -def test_get_sample(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - sample = file_wrapper.get_sample(0) - assert sample == b"a;b;c;d" - - sample = file_wrapper.get_sample(1) - assert sample == b"e;f;g;h" - - -def test_get_sample_with_invalid_index(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_sample(10) - - -def test_get_label(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - label = file_wrapper.get_label(0) - assert label == 12 - - label = file_wrapper.get_label(1) - assert label == 76 - - with pytest.raises(IndexError): - file_wrapper.get_label(2) - - -def test_get_all_labels(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_all_labels() == [12, 76] - - -def test_get_label_with_invalid_index(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_label(10) - - -def test_get_samples(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - samples = file_wrapper.get_samples(0, 1) - assert len(samples) == 1 - assert samples[0] == b"a;b;c;d" - - samples = file_wrapper.get_samples(0, 2) - assert len(samples) == 2 - assert samples[0] == b"a;b;c;d" - assert samples[1] == b"e;f;g;h" - - -def test_get_samples_with_invalid_index(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples(0, 5) - - with pytest.raises(IndexError): - file_wrapper.get_samples(3, 4) - - -def test_get_samples_from_indices_with_invalid_indices(): - file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples_from_indices([-2, 1]) - - -def write_to_file(data): - with open(CUSTOM_FILE_PATH, "wb") as file: - file.write(data) - - -def test_invalid_file_content(): - # extra field in one row - wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76" - write_to_file(wrong_data) - - with pytest.raises(ValueError): - _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - # label column outside boundary - wrong_data = b"a;b;c;12\ne;f;g;76" - write_to_file(wrong_data) - - with pytest.raises(ValueError): - _ = 
CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - # str label column - wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76" - write_to_file(wrong_data) - with pytest.raises(ValueError): - _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - # just one str in label - wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76" - write_to_file(wrong_data) - with pytest.raises(ValueError): - _ = CsvFileWrapper(CUSTOM_FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - -def test_invalid_file_content_skip_validation(): - # extra field in one row - wrong_data = b"a;b;c;d;12;e\ne;f;g;h;76" - write_to_file(wrong_data) - - config = FILE_WRAPPER_CONFIG.copy() - config["validate_file_content"] = False - - _ = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - # label column outside boundary - wrong_data = b"a;b;c;12\ne;f;g;76" - write_to_file(wrong_data) - - file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - with pytest.raises(IndexError): # fails since index > number of columns - file_wrapper.get_label(1) - - # str label column - wrong_data = b"a;b;c;d;e;12\ne;f;g;h;h;76" - write_to_file(wrong_data) - CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - with pytest.raises(ValueError): # fails to convert to integer - file_wrapper.get_label(1) - - # just one str in label - wrong_data = b"a;b;c;d;88;12\ne;f;g;h;h;76" - write_to_file(wrong_data) - CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - file_wrapper.get_label(0) # does not fail since row 0 is ok - with pytest.raises(ValueError): # fails to convert to integer - file_wrapper.get_label(1) - - -def test_different_separator(): - tsv_file_data = b"a\tb\tc\td\t12\ne\tf\tg\th\t76" - - tsv_file_wrapper_config = { - "ignore_first_line": False, - "label_index": 4, - "separator": "\t", - } - - write_to_file(tsv_file_data) - tsv_file_wrapper = CsvFileWrapper( - CUSTOM_FILE_PATH, tsv_file_wrapper_config, MockFileSystemWrapper(CUSTOM_FILE_PATH) - ) - csv_file_wrapper = CsvFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - - assert tsv_file_wrapper.get_number_of_samples() == csv_file_wrapper.get_number_of_samples() - - assert tsv_file_wrapper.get_sample(0) == b"a\tb\tc\td" - assert tsv_file_wrapper.get_sample(1) == b"e\tf\tg\th" - - tsv_samples = tsv_file_wrapper.get_samples(0, 2) - csv_samples = csv_file_wrapper.get_samples(0, 2) - - tsv_samples = [sample.decode("utf-8").split("\t") for sample in tsv_samples] - csv_samples = [sample.decode("utf-8").split(";") for sample in csv_samples] - assert tsv_samples == csv_samples - - assert tsv_file_wrapper.get_label(0) == csv_file_wrapper.get_label(0) - assert tsv_file_wrapper.get_label(1) == csv_file_wrapper.get_label(1) - - -def test_out_of_order_sequence(): - content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5" - converted = [b"A1;B1;C1", b"A2;B2;C2", b"A3;B3;C3", b"A4;B4;C4", b"A5;B5;C5"] - write_to_file(content) - config = { - "ignore_first_line": False, - "label_index": 3, - "separator": ";", - } - file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - # samples - assert file_wrapper.get_samples_from_indices([2, 1]) == [converted[2], converted[1]] - assert file_wrapper.get_samples_from_indices([3, 2, 1]) == [converted[3], converted[2], converted[1]] - assert 
file_wrapper.get_samples_from_indices([3, 2, 4, 1]) == [ - converted[3], - converted[2], - converted[4], - converted[1], - ] - - -def test_duplicate_request(): - content = b"A1;B1;C1;1\nA2;B2;C2;2\nA3;B3;C3;3\nA4;B4;C4;4\nA5;B5;C5;5" - write_to_file(content) - config = { - "ignore_first_line": False, - "label_index": 3, - "separator": ";", - } - file_wrapper = CsvFileWrapper(CUSTOM_FILE_PATH, config, MockFileSystemWrapper(CUSTOM_FILE_PATH)) - - with pytest.raises(AssertionError): - file_wrapper.get_samples_from_indices([1, 1]) - - with pytest.raises(AssertionError): - file_wrapper.get_samples_from_indices([1, 1, 3]) - - with pytest.raises(AssertionError): - file_wrapper.get_samples_from_indices([1, 1, 13]) diff --git a/modyn/tests/storage/internal/file_wrapper/test_file_wrapper_type.py b/modyn/tests/storage/internal/file_wrapper/test_file_wrapper_type.py deleted file mode 100644 index 784073cad..000000000 --- a/modyn/tests/storage/internal/file_wrapper/test_file_wrapper_type.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType - - -def test_invalid_file_wrapper_type(): - with pytest.raises(ValueError): - file_wrapper_type = FileWrapperType("invalid") - assert file_wrapper_type is None diff --git a/modyn/tests/storage/internal/file_wrapper/test_single_sample_file_wrapper.py b/modyn/tests/storage/internal/file_wrapper/test_single_sample_file_wrapper.py deleted file mode 100644 index a943b2ade..000000000 --- a/modyn/tests/storage/internal/file_wrapper/test_single_sample_file_wrapper.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import pathlib -import shutil - -import pytest -from modyn.storage.internal.file_wrapper.file_wrapper_type import FileWrapperType -from modyn.storage.internal.file_wrapper.single_sample_file_wrapper import SingleSampleFileWrapper - -TMP_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn") -FILE_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.png") -FILE_PATH_2 = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test_2.png") -INVALID_FILE_EXTENSION_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.txt") -METADATA_PATH = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test.json") -METADATA_PATH_2 = str(pathlib.Path(os.path.abspath(__file__)).parent / "test_tmp" / "modyn" / "test_2.json") -FILE_WRAPPER_CONFIG = {"file_extension": ".png", "label_file_extension": ".json"} -FILE_WRAPPER_CONFIG_MIN = {"file_extension": ".png"} - - -def setup(): - os.makedirs(TMP_DIR, exist_ok=True) - with open(FILE_PATH, "w", encoding="utf-8") as file: - file.write("test") - with open(METADATA_PATH, "wb") as file: - file.write("42".encode("utf-8")) - with open(METADATA_PATH_2, "w", encoding="utf-8") as file: - file.write("42") - - -def teardown(): - os.remove(FILE_PATH) - os.remove(METADATA_PATH) - shutil.rmtree(TMP_DIR) - - -class MockFileSystemWrapper: - def __init__(self, file_path): - self.file_path = file_path - - def get(self, file_path): - with open(file_path, "rb") as file: - return file.read() - - -def test_init(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.file_path == FILE_PATH - assert file_wrapper.file_wrapper_type == FileWrapperType.SingleSampleFileWrapper - - -def test_get_number_of_samples(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, 
FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_number_of_samples() == 1 - - -def test_get_number_of_samples_with_invalid_file_extension(): - file_wrapper = SingleSampleFileWrapper( - INVALID_FILE_EXTENSION_PATH, FILE_WRAPPER_CONFIG_MIN, MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH) - ) - assert file_wrapper.get_number_of_samples() == 0 - - -def test_get_samples(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - samples = file_wrapper.get_samples(0, 1) - assert len(samples) == 1 - assert samples[0].startswith(b"test") - - -def test_get_samples_with_invalid_indices(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples(0, 2) - - -def test_get_sample(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - sample = file_wrapper.get_sample(0) - assert sample.startswith(b"test") - - -def test_get_sample_with_invalid_file_extension(): - file_wrapper = SingleSampleFileWrapper( - INVALID_FILE_EXTENSION_PATH, FILE_WRAPPER_CONFIG_MIN, MockFileSystemWrapper(INVALID_FILE_EXTENSION_PATH) - ) - with pytest.raises(ValueError): - file_wrapper.get_sample(0) - - -def test_get_sample_with_invalid_index(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_sample(1) - - -def test_get_label(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - label = file_wrapper.get_label(0) - assert label == 42 - - file_wrapper = SingleSampleFileWrapper(FILE_PATH_2, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH_2)) - label = file_wrapper.get_label(0) - assert label == 42 - - -def test_get_all_labels(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - assert file_wrapper.get_all_labels() == [42] - - -def test_get_label_with_invalid_index(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_label(1) - - -def test_get_label_no_label(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG_MIN, MockFileSystemWrapper(FILE_PATH)) - label = file_wrapper.get_label(0) - assert label is None - - -def test_get_samples_from_indices(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - samples = file_wrapper.get_samples_from_indices([0]) - assert len(samples) == 1 - assert samples[0].startswith(b"test") - - -def test_get_samples_from_indices_with_invalid_indices(): - file_wrapper = SingleSampleFileWrapper(FILE_PATH, FILE_WRAPPER_CONFIG, MockFileSystemWrapper(FILE_PATH)) - with pytest.raises(IndexError): - file_wrapper.get_samples_from_indices([0, 1]) diff --git a/modyn/tests/storage/internal/filesystem_wrapper/filesystem_wrapper_utils_test.cpp b/modyn/tests/storage/internal/filesystem_wrapper/filesystem_wrapper_utils_test.cpp new file mode 100644 index 000000000..859508f5d --- /dev/null +++ b/modyn/tests/storage/internal/filesystem_wrapper/filesystem_wrapper_utils_test.cpp @@ -0,0 +1,13 @@ +#include "internal/filesystem_wrapper/filesystem_wrapper_utils.hpp" + +#include + +#include "storage_test_utils.hpp" + +using namespace modyn::storage; + +TEST(UtilsTest, 
TestGetFilesystemWrapper) { + const std::shared_ptr filesystem_wrapper = get_filesystem_wrapper(FilesystemWrapperType::LOCAL); + ASSERT_NE(filesystem_wrapper, nullptr); + ASSERT_EQ(filesystem_wrapper->get_type(), FilesystemWrapperType::LOCAL); +} \ No newline at end of file diff --git a/modyn/tests/storage/internal/filesystem_wrapper/local_filesystem_wrapper_test.cpp b/modyn/tests/storage/internal/filesystem_wrapper/local_filesystem_wrapper_test.cpp new file mode 100644 index 000000000..be461df27 --- /dev/null +++ b/modyn/tests/storage/internal/filesystem_wrapper/local_filesystem_wrapper_test.cpp @@ -0,0 +1,149 @@ +#include "internal/filesystem_wrapper/local_filesystem_wrapper.hpp" + +#include +#include +#include + +#include +#include +#include + +#include "gmock/gmock.h" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; + +const char path_seperator = '/'; + +const std::string current_dir = std::filesystem::current_path(); // NOLINT cert-err58-cpp +const std::string test_base_dir = current_dir + path_seperator + "test_dir"; // NOLINT cert-err58-cpp + +class LocalFilesystemWrapperTest : public ::testing::Test { + protected: + void SetUp() override { + const std::string test_dir = current_dir + path_seperator + "test_dir"; + std::filesystem::create_directory(test_dir); + + const std::string test_dir_2 = test_dir + path_seperator + "test_dir_2"; + std::filesystem::create_directory(test_dir_2); + + const std::string test_file = test_dir + path_seperator + "test_file.txt"; + std::ofstream file(test_file, std::ios::binary); + file << "12345678"; + file.close(); + + const time_t zero_time = 0; + utimbuf ub = {}; + ub.modtime = zero_time; + + utime(test_file.c_str(), &ub); + + const std::string test_file_2 = test_dir_2 + path_seperator + "test_file_2.txt"; + std::ofstream file_2(test_file_2, std::ios::binary); + file_2 << "12345678"; + file_2.close(); + } + + void TearDown() override { + const std::string current_dir = std::filesystem::current_path(); + + const std::string test_dir = current_dir + path_seperator + "test_dir"; + std::filesystem::remove_all(test_dir); + } +}; + +TEST_F(LocalFilesystemWrapperTest, TestGet) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + std::vector bytes = filesystem_wrapper.get(file_name); + ASSERT_EQ(bytes.size(), 8); + ASSERT_EQ((bytes)[0], '1'); + ASSERT_EQ((bytes)[1], '2'); + ASSERT_EQ((bytes)[2], '3'); + ASSERT_EQ((bytes)[3], '4'); + ASSERT_EQ((bytes)[4], '5'); + ASSERT_EQ((bytes)[5], '6'); + ASSERT_EQ((bytes)[6], '7'); + ASSERT_EQ((bytes)[7], '8'); +} + +TEST_F(LocalFilesystemWrapperTest, TestExists) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + const std::string file_name_2 = test_base_dir + path_seperator + "test_file_2.txt"; + LocalFilesystemWrapper filesystem_wrapper; + ASSERT_TRUE(filesystem_wrapper.exists(file_name)); + ASSERT_FALSE(filesystem_wrapper.exists(file_name_2)); +} + +TEST_F(LocalFilesystemWrapperTest, TestList) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper; + std::vector files = filesystem_wrapper.list(test_base_dir, /*recursive=*/false, ".txt"); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + 
ASSERT_EQ(files.size(), 1); + ASSERT_EQ((files)[0], file_name); +} + +TEST_F(LocalFilesystemWrapperTest, TestListRecursive) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper; + std::vector files = filesystem_wrapper.list(test_base_dir, /*recursive=*/true, ".txt"); + ASSERT_EQ(files.size(), 2); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + const std::string file_name_2 = test_base_dir + path_seperator + "test_dir_2/test_file_2.txt"; + ASSERT_TRUE(std::find(files.begin(), files.end(), file_name) != files.end()); + ASSERT_TRUE(std::find(files.begin(), files.end(), file_name_2) != files.end()); +} + +TEST_F(LocalFilesystemWrapperTest, TestIsDirectory) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + ASSERT_TRUE(filesystem_wrapper.is_directory(test_base_dir)); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_FALSE(filesystem_wrapper.is_directory(file_name)); + ASSERT_TRUE(filesystem_wrapper.is_directory(test_base_dir)); +} + +TEST_F(LocalFilesystemWrapperTest, TestIsFile) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + ASSERT_FALSE(filesystem_wrapper.is_file(test_base_dir)); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_TRUE(filesystem_wrapper.is_file(file_name)); + ASSERT_FALSE(filesystem_wrapper.is_file(test_base_dir)); +} + +TEST_F(LocalFilesystemWrapperTest, TestGetFileSize) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_EQ(filesystem_wrapper.get_file_size(file_name), 8); +} + +TEST_F(LocalFilesystemWrapperTest, TestGetModifiedTime) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_EQ(filesystem_wrapper.get_modified_time(file_name), 0); +} + +TEST_F(LocalFilesystemWrapperTest, TestIsValidPath) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_TRUE(filesystem_wrapper.is_valid_path(test_base_dir)); + ASSERT_TRUE(filesystem_wrapper.is_valid_path(file_name)); + ASSERT_FALSE(filesystem_wrapper.is_valid_path("invalid_path")); +} + +TEST_F(LocalFilesystemWrapperTest, TestRemove) { + const YAML::Node config = modyn::test::TestUtils::get_dummy_config(); + LocalFilesystemWrapper filesystem_wrapper = LocalFilesystemWrapper(); + const std::string file_name = test_base_dir + path_seperator + "test_file.txt"; + ASSERT_TRUE(filesystem_wrapper.exists(file_name)); + filesystem_wrapper.remove(file_name); + ASSERT_FALSE(filesystem_wrapper.exists(file_name)); +} \ No newline at end of file diff --git a/modyn/tests/storage/internal/filesystem_wrapper/mock_filesystem_wrapper.hpp b/modyn/tests/storage/internal/filesystem_wrapper/mock_filesystem_wrapper.hpp new file mode 100644 index 000000000..242e7b6c2 --- /dev/null +++ 
b/modyn/tests/storage/internal/filesystem_wrapper/mock_filesystem_wrapper.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include + +#include "gmock/gmock.h" +#include "internal/filesystem_wrapper/filesystem_wrapper.hpp" +#include "storage_test_utils.hpp" + +namespace modyn::storage { +class MockFilesystemWrapper : public FilesystemWrapper { + public: + MockFilesystemWrapper() : FilesystemWrapper() {} // NOLINT + MOCK_METHOD(std::vector, get, (const std::string& path), (override)); + MOCK_METHOD(bool, exists, (const std::string& path), (override)); + MOCK_METHOD(std::vector, list, (const std::string& path, bool recursive, std::string extension), + (override)); + MOCK_METHOD(bool, is_directory, (const std::string& path), (override)); + MOCK_METHOD(bool, is_file, (const std::string& path), (override)); + MOCK_METHOD(uint64_t, get_file_size, (const std::string& path), (override)); + MOCK_METHOD(int64_t, get_modified_time, (const std::string& path), (override)); + MOCK_METHOD(bool, is_valid_path, (const std::string& path), (override)); + MOCK_METHOD(std::shared_ptr, get_stream, (const std::string& path), (override)); + MOCK_METHOD(FilesystemWrapperType, get_type, (), (override)); + MOCK_METHOD(bool, remove, (const std::string& path), (override)); + ~MockFilesystemWrapper() override = default; + MockFilesystemWrapper(const MockFilesystemWrapper&) = delete; + MockFilesystemWrapper& operator=(const MockFilesystemWrapper&) = delete; + MockFilesystemWrapper(MockFilesystemWrapper&&) = delete; + MockFilesystemWrapper& operator=(MockFilesystemWrapper&&) = delete; +}; +} // namespace modyn::storage diff --git a/modyn/tests/storage/internal/filesystem_wrapper/test_filesystem_wrapper_type.py b/modyn/tests/storage/internal/filesystem_wrapper/test_filesystem_wrapper_type.py deleted file mode 100644 index 8555be650..000000000 --- a/modyn/tests/storage/internal/filesystem_wrapper/test_filesystem_wrapper_type.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest -from modyn.storage.internal.filesystem_wrapper.filesystem_wrapper_type import FilesystemWrapperType - - -def test_invalid_filesystem_wrapper_type(): - with pytest.raises(ValueError): - filesystem_wrapper_type = FilesystemWrapperType("invalid") - assert filesystem_wrapper_type is None diff --git a/modyn/tests/storage/internal/filesystem_wrapper/test_local_filesystem_wrapper.py b/modyn/tests/storage/internal/filesystem_wrapper/test_local_filesystem_wrapper.py deleted file mode 100644 index 771900d88..000000000 --- a/modyn/tests/storage/internal/filesystem_wrapper/test_local_filesystem_wrapper.py +++ /dev/null @@ -1,259 +0,0 @@ -import os -import pathlib - -import pytest -from modyn.storage.internal.filesystem_wrapper.local_filesystem_wrapper import LocalFilesystemWrapper - -TEST_DIR = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "modyn" / "test_dir") -TEST_FILE = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "modyn" / "test_dir" / "test_file") -TEST_FILE_MODIFIED_AT = None -TEST_DIR2 = str(pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "modyn" / "test_dir" / "test_dir") -TEST_FILE2 = str( - pathlib.Path(os.path.abspath(__file__)).parent / "tmp" / "modyn" / "test_dir" / "test_dir" / "test_file2" -) -TEST_FILE2_MODIFIED_AT = None - - -def setup(): - os.makedirs(TEST_DIR, exist_ok=True) - - with open(TEST_FILE, "w", encoding="utf8") as file: - file.write("test1") - - global TEST_FILE_MODIFIED_AT #  pylint: disable=global-statement # noqa: E262 - TEST_FILE_MODIFIED_AT = int(os.path.getmtime(TEST_FILE) * 1000) - - 
os.makedirs(TEST_DIR2, exist_ok=True) - - with open(TEST_FILE2, "w", encoding="utf8") as file: - file.write("test2 long") - - global TEST_FILE2_MODIFIED_AT #  pylint: disable=global-statement # noqa: E262 - TEST_FILE2_MODIFIED_AT = int(os.path.getmtime(TEST_FILE2) * 1000) - - -def teardown(): - os.remove(TEST_FILE) - os.remove(TEST_FILE2) - os.rmdir(TEST_DIR2) - os.rmdir(TEST_DIR) - - -def test_init(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.base_path == TEST_DIR - - -def test_get(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - file = filesystem_wrapper.get(TEST_FILE) - assert file == b"test1" - - file = filesystem_wrapper.get(TEST_FILE2) - assert file == b"test2 long" - - -def test_get_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get("not_found") - - -def test_get_directory(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(IsADirectoryError): - filesystem_wrapper.get(TEST_DIR2) - - -def test_get_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_exists(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.exists(TEST_FILE) - assert filesystem_wrapper.exists(TEST_FILE2) - assert filesystem_wrapper.exists(TEST_DIR) - assert filesystem_wrapper.exists(TEST_DIR2) - - -def test_exists_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.exists("not_found") - - -def test_exists_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.exists(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_list(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert set(filesystem_wrapper.list(TEST_DIR)) == set(["test_file", "test_dir"]) - assert filesystem_wrapper.list(TEST_DIR2) == ["test_file2"] - - -def test_list_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.list("not_found") - - -def test_list_not_directory(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(NotADirectoryError): - filesystem_wrapper.list(TEST_FILE) - - -def test_list_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.list(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_list_recursive(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert set(filesystem_wrapper.list(TEST_DIR, recursive=True)) == set([TEST_FILE, TEST_FILE2]) - - -def test_list_recursive_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.list("not_found", recursive=True) - - -def test_list_recursive_not_directory(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(NotADirectoryError): - filesystem_wrapper.list(TEST_FILE, recursive=True) - - -def test_list_recursive_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.list(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path"), recursive=True) - - -def test_isdir(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert 
filesystem_wrapper.isdir(TEST_DIR) - assert filesystem_wrapper.isdir(TEST_DIR2) - - -def test_isdir_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isdir("not_found") - - -def test_isdir_not_directory(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isdir(TEST_FILE) - - -def test_isdir_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isdir(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_isfile(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.isfile(TEST_FILE) - assert filesystem_wrapper.isfile(TEST_FILE2) - - -def test_isfile_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isfile("not_found") - - -def test_isfile_not_directory(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isfile(TEST_DIR) - - -def test_isfile_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert not filesystem_wrapper.isfile(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_get_size(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.get_size(TEST_FILE) == 5 - assert filesystem_wrapper.get_size(TEST_FILE2) == 10 - - -def test_get_size_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get_size("not_found") - - -def test_get_size_not_file(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(IsADirectoryError): - filesystem_wrapper.get_size(TEST_DIR) - - -def test_get_size_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get_size(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_get_modified(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.get_modified(TEST_FILE) == TEST_FILE_MODIFIED_AT - assert filesystem_wrapper.get_modified(TEST_FILE2) == TEST_FILE2_MODIFIED_AT - - -def test_get_modified_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get_modified("not_found") - - -def test_get_modified_not_file(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(IsADirectoryError): - filesystem_wrapper.get_modified(TEST_DIR) - - -def test_get_modified_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get_modified(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_get_created(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.get_created(TEST_FILE) == TEST_FILE_MODIFIED_AT - assert filesystem_wrapper.get_created(TEST_FILE2) == TEST_FILE2_MODIFIED_AT - - -def test_get_created_not_found(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - filesystem_wrapper.get_created("not_found") - - -def test_get_created_not_file(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(IsADirectoryError): - filesystem_wrapper.get_created(TEST_DIR) - - -def test_get_created_not_in_base_path(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - with pytest.raises(ValueError): - 
filesystem_wrapper.get_created(os.path.sep + os.path.join("tmp", "modyn", "not_in_base_path")) - - -def test_join(): - filesystem_wrapper = LocalFilesystemWrapper(TEST_DIR) - assert filesystem_wrapper.join("a", "b") == "a" + os.path.sep + "b" - assert filesystem_wrapper.join("a", "b", "c") == "a" + os.path.sep + "b" + os.path.sep + "c" - assert ( - filesystem_wrapper.join("a", "b", "c", "d") == "a" + os.path.sep + "b" + os.path.sep + "c" + os.path.sep + "d" - ) diff --git a/modyn/tests/storage/internal/grpc/storage_service_impl_test.cpp b/modyn/tests/storage/internal/grpc/storage_service_impl_test.cpp new file mode 100644 index 000000000..8bb9b3851 --- /dev/null +++ b/modyn/tests/storage/internal/grpc/storage_service_impl_test.cpp @@ -0,0 +1,766 @@ +#include "internal/grpc/storage_service_impl.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "gmock/gmock.h" +#include "internal/database/storage_database_connection.hpp" +#include "internal/filesystem_wrapper/mock_filesystem_wrapper.hpp" +#include "storage_test_utils.hpp" +#include "test_utils.hpp" + +using namespace modyn::storage; +using namespace grpc; + +class StorageServiceImplTest : public ::testing::Test { + protected: + std::string tmp_dir_; + int64_t early_sample_id_ = -1; + int64_t late_sample_id_ = -1; + + StorageServiceImplTest() : tmp_dir_{std::filesystem::temp_directory_path().string() + "/storage_service_impl_test"} {} + + void SetUp() override { + modyn::test::TestUtils::create_dummy_yaml(); + // Create temporary directory + std::filesystem::create_directory(tmp_dir_); + const YAML::Node config = YAML::LoadFile("config.yaml"); + const StorageDatabaseConnection connection(config); + connection.create_tables(); + + // Add a dataset to the database + connection.add_dataset("test_dataset", tmp_dir_, FilesystemWrapperType::LOCAL, FileWrapperType::SINGLE_SAMPLE, + "test description", "0.0.0", StorageTestUtils::get_dummy_file_wrapper_config_inline(), + /*ignore_last_timestamp=*/true); + + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + std::string sql_expression = fmt::format( + "INSERT INTO files (dataset_id, path, updated_at, number_of_samples) VALUES (1, '{}/test_file.txt', 100, " + "1)", + tmp_dir_); + session << sql_expression; + + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 1, 0, 0)"; + long long inserted_id = -1; // NOLINT google-runtime-int (Linux otherwise complains about the following call) + if (!session.get_last_insert_id("samples", inserted_id)) { + FAIL("Failed to insert sample into database"); + } + late_sample_id_ = static_cast(inserted_id); + + sql_expression = fmt::format( + "INSERT INTO files (dataset_id, path, updated_at, number_of_samples) VALUES (1, '{}/test_file2.txt', " + "1, 1)", + tmp_dir_); + session << sql_expression; + + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 2, 0, 1)"; + inserted_id = -1; // NOLINT google-runtime-int (Linux otherwise complains about the following call) + if (!session.get_last_insert_id("samples", inserted_id)) { + FAIL("Failed to insert sample into database"); + } + early_sample_id_ = static_cast(inserted_id); + + // Create dummy files + const std::string test_file_path = tmp_dir_ + "/test_file.txt"; + std::ofstream test_file(test_file_path); + ASSERT(test_file.is_open(), "Could not open test file"); + test_file << "test"; + test_file.close(); + ASSERT(!test_file.is_open(), 
"Could not close test file"); + + const std::string label_file_path = tmp_dir_ + "/test_file.lbl"; + std::ofstream label_file(label_file_path); + ASSERT(label_file.is_open(), "Could not open label file"); + label_file << "1"; + label_file.close(); + ASSERT(!label_file.is_open(), "Could not close label file"); + + const std::string test_file_path2 = tmp_dir_ + "/test_file2.txt"; + std::ofstream test_file2(test_file_path2); + ASSERT(test_file2.is_open(), "Could not open test file"); + test_file2 << "test"; + test_file2.close(); + ASSERT(!test_file2.is_open(), "Could not close test file"); + + const std::string label_file_path2 = tmp_dir_ + "/test_file2.lbl"; + std::ofstream label_file2(label_file_path2); + ASSERT(label_file2.is_open(), "Could not open label file"); + label_file2 << "2"; + label_file2.close(); + ASSERT(!label_file2.is_open(), "Could not close label file"); + } + + void TearDown() override { + // Remove temporary directory + std::filesystem::remove_all(tmp_dir_); + std::filesystem::remove("config.yaml"); + if (std::filesystem::exists("test.db")) { + std::filesystem::remove("test.db"); + } + } +}; + +TEST_F(StorageServiceImplTest, TestCheckAvailability) { + ServerContext context; + + modyn::storage::DatasetAvailableRequest request; + request.set_dataset_id("test_dataset"); + + modyn::storage::DatasetAvailableResponse response; + + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); + + Status status = storage_service.CheckAvailability(&context, &request, &response); + + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(response.available()); + + request.set_dataset_id("non_existing_dataset"); + status = storage_service.CheckAvailability(&context, &request, &response); + + EXPECT_FALSE(response.available()); +} + +TEST_F(StorageServiceImplTest, TestGetCurrentTimestamp) { + ServerContext context; + + const modyn::storage::GetCurrentTimestampRequest request; + + modyn::storage::GetCurrentTimestampResponse response; + + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); + + const Status status = storage_service.GetCurrentTimestamp(&context, &request, &response); + + EXPECT_TRUE(status.ok()); + EXPECT_GE(response.timestamp(), 0); +} + +TEST_F(StorageServiceImplTest, TestDeleteDataset) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); + + const StorageDatabaseConnection connection(config); + + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + modyn::storage::DatasetAvailableRequest request; + request.set_dataset_id("test_dataset"); + + modyn::storage::DeleteDatasetResponse response; + + ServerContext context; + + int dataset_exists = 0; + session << "SELECT COUNT(*) FROM datasets WHERE name = 'test_dataset'", soci::into(dataset_exists); + + ASSERT_TRUE(dataset_exists); + + const Status status = storage_service.DeleteDataset(&context, &request, &response); + + ASSERT_TRUE(status.ok()); + + ASSERT_TRUE(response.success()); + + dataset_exists = 0; + session << "SELECT COUNT(*) FROM datasets WHERE name = 'test_dataset'", soci::into(dataset_exists); + + ASSERT_FALSE(dataset_exists); +} + +TEST_F(StorageServiceImplTest, TestDeleteData) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); // NOLINT misc-const-correctness + + modyn::storage::DeleteDataRequest request; + request.set_dataset_id("test_dataset"); + 
request.add_keys(1); + + // Add an additional sample for file 1 to the database + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 1, 1, 0)"; + + modyn::storage::DeleteDataResponse response; + + ServerContext context; + + Status status = storage_service.DeleteData(&context, &request, &response); + + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(response.success()); + + int number_of_samples = 0; + session << "SELECT COUNT(*) FROM samples WHERE dataset_id = 1", soci::into(number_of_samples); + + ASSERT_EQ(number_of_samples, 2); + + ASSERT_FALSE(std::filesystem::exists(tmp_dir_ + "/test_file.txt")); + + ASSERT_TRUE(std::filesystem::exists(tmp_dir_ + "/test_file2.txt")); + + request.clear_keys(); + + status = storage_service.DeleteData(&context, &request, &response); + + request.add_keys(1); + + status = storage_service.DeleteData(&context, &request, &response); + + request.clear_keys(); + request.add_keys(2); + + status = storage_service.DeleteData(&context, &request, &response); + + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(response.success()); + + number_of_samples = 0; + session << "SELECT COUNT(*) FROM samples WHERE dataset_id = 1", soci::into(number_of_samples); + + ASSERT_EQ(number_of_samples, 1); +} + +// NOLINTNEXTLINE (readability-function-cognitive-complexity) +TEST_F(StorageServiceImplTest, TestGetNewDataSince) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); // NOLINT misc-const-correctness + grpc::ServerContext context; + grpc::internal::Call call; + modyn::storage::MockServerWriter writer(&call, &context); + + modyn::storage::GetNewDataSinceRequest request; + request.set_dataset_id("test_dataset"); + request.set_timestamp(0); + + grpc::Status status = + storage_service.GetNewDataSince_Impl>( + &context, &request, &writer); + + ASSERT_TRUE(status.ok()); + const std::vector& responses = writer.get_responses(); + ASSERT_EQ(responses.size(), 1); + const modyn::storage::GetNewDataSinceResponse& response = responses[0]; + + std::vector keys; + keys.reserve(response.keys_size()); + for (const auto& key : response.keys()) { + keys.push_back(key); + } + + ASSERT_THAT(keys, ::testing::UnorderedElementsAre(early_sample_id_, late_sample_id_)); + + // Now try only the second file + + modyn::storage::MockServerWriter writer2(&call, &context); + request.set_timestamp(50); + status = + storage_service.GetNewDataSince_Impl>( + &context, &request, &writer2); + ASSERT_TRUE(status.ok()); + const std::vector& responses2 = writer2.get_responses(); + ASSERT_EQ(responses2.size(), 1); + const modyn::storage::GetNewDataSinceResponse& response2 = responses2[0]; + std::vector keys2; + keys2.reserve(response2.keys_size()); + for (const auto& key : response2.keys()) { + keys2.push_back(key); + } + + ASSERT_THAT(keys2, ::testing::ElementsAre(late_sample_id_)); + + // And now no files + modyn::storage::MockServerWriter writer3(&call, &context); + request.set_timestamp(101); + status = + storage_service.GetNewDataSince_Impl>( + &context, &request, &writer3); + ASSERT_TRUE(status.ok()); + const std::vector& responses3 = writer3.get_responses(); + ASSERT_EQ(responses3.size(), 0); +} + +TEST_F(StorageServiceImplTest, TestGetDataInInterval) { // NOLINT(readability-function-cognitive-complexity) + const YAML::Node config = YAML::LoadFile("config.yaml"); + 
StorageServiceImpl storage_service(config); // NOLINT misc-const-correctness + grpc::ServerContext context; + grpc::internal::Call call; + modyn::storage::MockServerWriter writer(&call, &context); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + const std::string sql_expression = fmt::format( + "INSERT INTO files (dataset_id, path, updated_at, number_of_samples) VALUES (1, '{}/non_existing.txt', 200, " + "1)", + tmp_dir_); + session << sql_expression; + + long long inserted_file_id = -1; // NOLINT google-runtime-int (soci needs ll) + if (!session.get_last_insert_id("files", inserted_file_id)) { + FAIL("Failed to insert file into database"); + } + + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, :file, 0, 0)", + soci::use(inserted_file_id); + long long inserted_sample_id_ll = // NOLINT google-runtime-int (soci needs ll) + -1; + if (!session.get_last_insert_id("samples", inserted_sample_id_ll)) { + FAIL("Failed to insert sample into database"); + } + + auto inserted_sample_id = static_cast(inserted_sample_id_ll); + + modyn::storage::GetDataInIntervalRequest request; + request.set_dataset_id("test_dataset"); + request.set_start_timestamp(0); + request.set_end_timestamp(250); + + grpc::Status status = + storage_service + .GetDataInInterval_Impl>( + &context, &request, &writer); + + ASSERT_TRUE(status.ok()); + const std::vector& responses = writer.get_responses(); + ASSERT_EQ(responses.size(), 1); + const modyn::storage::GetDataInIntervalResponse& response = responses[0]; + + std::vector keys; + keys.reserve(response.keys_size()); + for (const auto& key : response.keys()) { + keys.push_back(key); + } + + ASSERT_THAT(keys, ::testing::UnorderedElementsAre(early_sample_id_, late_sample_id_, inserted_sample_id)); + + // Now try only the last 2 files + + modyn::storage::MockServerWriter writer2(&call, &context); + request.set_start_timestamp(50); + request.set_end_timestamp(250); + + status = storage_service + .GetDataInInterval_Impl>( + &context, &request, &writer2); + ASSERT_TRUE(status.ok()); + const std::vector& responses2 = writer2.get_responses(); + ASSERT_EQ(responses2.size(), 1); + const modyn::storage::GetDataInIntervalResponse& response2 = responses2[0]; + std::vector keys2; + keys2.reserve(response2.keys_size()); + for (const auto& key : response2.keys()) { + keys2.push_back(key); + } + ASSERT_THAT(keys2, ::testing::UnorderedElementsAre(late_sample_id_, inserted_sample_id)); + + // And now no files + modyn::storage::MockServerWriter writer3(&call, &context); + request.set_start_timestamp(101); + request.set_end_timestamp(180); + status = storage_service + .GetDataInInterval_Impl>( + &context, &request, &writer3); + ASSERT_TRUE(status.ok()); + const std::vector& responses3 = writer3.get_responses(); + ASSERT_EQ(responses3.size(), 0); +} + +TEST_F(StorageServiceImplTest, TestDeleteDataErrorHandling) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + StorageServiceImpl storage_service(config); + + modyn::storage::DeleteDataRequest request; + modyn::storage::DeleteDataResponse response; + + ServerContext context; + + // Test case when dataset does not exist + request.set_dataset_id("non_existent_dataset"); + request.add_keys(1); + Status status = storage_service.DeleteData(&context, &request, &response); + ASSERT_FALSE(response.success()); + + // Test case when no samples found for provided keys + 
request.set_dataset_id("test_dataset"); + request.clear_keys(); + request.add_keys(99999); // Assuming no sample with this key + status = storage_service.DeleteData(&context, &request, &response); + ASSERT_FALSE(response.success()); + + // Test case when no files found for the samples + // Here we create a sample that doesn't link to a file. + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + session + << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 99999, 0, 0)"; // Assuming no file + // with this id + request.clear_keys(); + request.add_keys(0); + status = storage_service.DeleteData(&context, &request, &response); + ASSERT_FALSE(response.success()); +} + +TEST_F(StorageServiceImplTest, TestGetPartitionForWorker) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + std::tuple result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_partition_for_worker(0, 1, 1)); + ASSERT_EQ(std::get<0>(result), 0); + ASSERT_EQ(std::get<1>(result), 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_partition_for_worker(0, 2, 2)); + ASSERT_EQ(std::get<0>(result), 0); + ASSERT_EQ(std::get<1>(result), 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_partition_for_worker(1, 2, 2)); + ASSERT_EQ(std::get<0>(result), 1); + ASSERT_EQ(std::get<1>(result), 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_partition_for_worker(0, 3, 9)); + ASSERT_EQ(std::get<0>(result), 0); + ASSERT_EQ(std::get<1>(result), 3); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_partition_for_worker(1, 3, 9)); + ASSERT_EQ(std::get<0>(result), 3); + ASSERT_EQ(std::get<1>(result), 3); +} + +TEST_F(StorageServiceImplTest, TestGetNumberOfSamplesInFile) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + int64_t result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_number_of_samples_in_file(1, session, 1)); + ASSERT_EQ(result, 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_number_of_samples_in_file(2, session, 1)); + ASSERT_EQ(result, 1); + + const std::string sql_expression = fmt::format( + "INSERT INTO files (dataset_id, path, updated_at, number_of_samples) VALUES (1, '{}/test_file2.txt', " + "100, 10)", + tmp_dir_); + session << sql_expression; + + ASSERT_NO_THROW(result = StorageServiceImpl::get_number_of_samples_in_file(3, session, 1)); + ASSERT_EQ(result, 10); +} + +TEST_F(StorageServiceImplTest, TestGetFileIds) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + std::vector result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1, 1, 100)); + // File 1 has timestamp 100, file 2 has timestamp 1 + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); + ASSERT_EQ(result[1], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1, 1, 1)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 2); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1, 2, 100)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1)); + 
ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); + ASSERT_EQ(result[1], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1, 2)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids(session, 1, 1, 100)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); + ASSERT_EQ(result[1], 1); +} + +TEST_F(StorageServiceImplTest, TestGetFileCount) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + int64_t result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_count(session, 1, 1, 100)); + ASSERT_EQ(result, 2); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_count(session, 1, 1, 1)); + ASSERT_EQ(result, 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_count(session, 1, 2, 100)); + ASSERT_EQ(result, 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_count(session, 1, -1, -1)); + ASSERT_EQ(result, 2); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_count(session, 1, 2, -1)); + ASSERT_EQ(result, 1); +} + +TEST_F(StorageServiceImplTest, TestGetFileIdsGivenNumberOfFiles) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + std::vector result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_given_number_of_files(session, 1, 1, 100, 2)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); // file 2 has timestamp 1, file 1 has timestamp 100, return result is ordered + ASSERT_EQ(result[1], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_given_number_of_files(session, 1, 1, 1, 1)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 2); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_given_number_of_files(session, 1, 2, 100, 1)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_given_number_of_files(session, 1, -1, -1, 2)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); + ASSERT_EQ(result[1], 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_given_number_of_files(session, 1, 2, -1, 1)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 1); +} + +TEST_F(StorageServiceImplTest, TestGetDatasetId) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + int64_t result; + ASSERT_NO_THROW(result = StorageServiceImpl::get_dataset_id(session, "test_dataset")); + ASSERT_EQ(result, 1); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_dataset_id(session, "non_existent_dataset")); + ASSERT_EQ(result, -1); +} + +TEST_F(StorageServiceImplTest, TestGetFileIdsForSamples) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 1, 0, 1)"; + session << "INSERT INTO samples (dataset_id, 
file_id, sample_index, label) VALUES (1, 2, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 3, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 4, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 4, 0, 1)"; + + std::vector result; + std::vector request_keys = {1, 2, 3}; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_for_samples(request_keys, 1, session)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 1); + ASSERT_EQ(result[1], 2); + + request_keys = {3, 4, 5, 6}; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_for_samples(request_keys, 1, session)); + ASSERT_EQ(result.size(), 4); + ASSERT_EQ(result[0], 1); + ASSERT_EQ(result[1], 2); + ASSERT_EQ(result[2], 3); + ASSERT_EQ(result[3], 4); + + request_keys = {3, 4}; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_for_samples(request_keys, 1, session)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 1); + ASSERT_EQ(result[1], 2); + + request_keys = {1, 2, 3, 4, 5, 6, 7}; + ASSERT_NO_THROW(result = StorageServiceImpl::get_file_ids_for_samples(request_keys, 1, session)); + ASSERT_EQ(result.size(), 4); + ASSERT_EQ(result[0], 1); + ASSERT_EQ(result[1], 2); + ASSERT_EQ(result[2], 3); + ASSERT_EQ(result[3], 4); +} + +TEST_F(StorageServiceImplTest, TestGetFileIdsPerThread) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + + std::vector::const_iterator, std::vector::const_iterator>> iterator_result; + std::vector file_ids = {1, 2, 3, 4, 5}; + ASSERT_NO_THROW(iterator_result = StorageServiceImpl::get_keys_per_thread(file_ids, 1)); + + std::vector> result; + for (const auto& its : iterator_result) { + std::vector thread_result; + for (auto it = its.first; it < its.second; ++it) { + thread_result.push_back(*it); + } + result.push_back(thread_result); + } + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].size(), 5); + ASSERT_EQ(result[0][0], 1); + ASSERT_EQ(result[0][1], 2); + ASSERT_EQ(result[0][2], 3); + ASSERT_EQ(result[0][3], 4); + ASSERT_EQ(result[0][4], 5); + + ASSERT_NO_THROW(iterator_result = StorageServiceImpl::get_keys_per_thread(file_ids, 2)); + result.clear(); + for (const auto& its : iterator_result) { + std::vector thread_result; + for (auto it = its.first; it < its.second; ++it) { + thread_result.push_back(*it); + } + result.push_back(thread_result); + } + + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0].size(), 2); + ASSERT_EQ(result[0][0], 1); + ASSERT_EQ(result[0][1], 2); + ASSERT_EQ(result[1].size(), 3); + ASSERT_EQ(result[1][0], 3); + ASSERT_EQ(result[1][1], 4); + ASSERT_EQ(result[1][2], 5); + + ASSERT_NO_THROW(iterator_result = StorageServiceImpl::get_keys_per_thread(file_ids, 3)); + result.clear(); + for (const auto& its : iterator_result) { + std::vector thread_result; + for (auto it = its.first; it < its.second; ++it) { + thread_result.push_back(*it); + } + result.push_back(thread_result); + } + ASSERT_EQ(result.size(), 3); + ASSERT_EQ(result[0].size(), 1); + ASSERT_EQ(result[0][0], 1); + ASSERT_EQ(result[1].size(), 1); + ASSERT_EQ(result[1][0], 2); + ASSERT_EQ(result[2].size(), 3); + ASSERT_EQ(result[2][0], 3); + ASSERT_EQ(result[2][1], 4); + ASSERT_EQ(result[2][2], 5); + + file_ids = {1}; + ASSERT_NO_THROW(iterator_result = StorageServiceImpl::get_keys_per_thread(file_ids, 1)); + result.clear(); + for (const auto& its : iterator_result) { + std::vector 
thread_result; + for (auto it = its.first; it < its.second; ++it) { + thread_result.push_back(*it); + } + result.push_back(thread_result); + } + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].size(), 1); + ASSERT_EQ(result[0][0], 1); + + ASSERT_NO_THROW(iterator_result = StorageServiceImpl::get_keys_per_thread(file_ids, 2)); + result.clear(); + for (const auto& its : iterator_result) { + std::vector thread_result; + for (auto it = its.first; it < its.second; ++it) { + thread_result.push_back(*it); + } + result.push_back(thread_result); + } + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0].size(), 1); + ASSERT_EQ(result[0][0], 1); + ASSERT_EQ(result[1].size(), 0); +} + +TEST_F(StorageServiceImplTest, TestGetSamplesCorrespondingToFiles) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 1, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 2, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 3, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 4, 0, 1)"; + session << "INSERT INTO samples (dataset_id, file_id, sample_index, label) VALUES (1, 4, 0, 1)"; + + std::vector result; + const std::vector request_keys = {1, 2, 3, 4, 5, 6, 7}; + ASSERT_NO_THROW(result = StorageServiceImpl::get_samples_corresponding_to_file(1, 1, request_keys, session)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 1); + ASSERT_EQ(result[1], 3); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_samples_corresponding_to_file(2, 1, request_keys, session)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 2); + ASSERT_EQ(result[1], 4); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_samples_corresponding_to_file(3, 1, request_keys, session)); + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0], 5); + + ASSERT_NO_THROW(result = StorageServiceImpl::get_samples_corresponding_to_file(4, 1, request_keys, session)); + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result[0], 6); + ASSERT_EQ(result[1], 7); +} + +TEST_F(StorageServiceImplTest, TestGetDatasetData) { + const YAML::Node config = YAML::LoadFile("config.yaml"); + + const StorageDatabaseConnection connection(config); + soci::session session = + connection.get_session(); // NOLINT misc-const-correctness (the soci::session cannot be const) + + DatasetData result; + std::string dataset_name = "test_dataset"; + ASSERT_NO_THROW(result = StorageServiceImpl::get_dataset_data(session, dataset_name)); + ASSERT_EQ(result.dataset_id, 1); + ASSERT_EQ(result.base_path, tmp_dir_); + ASSERT_EQ(result.filesystem_wrapper_type, FilesystemWrapperType::LOCAL); + ASSERT_EQ(result.file_wrapper_type, FileWrapperType::SINGLE_SAMPLE); + ASSERT_EQ(result.file_wrapper_config, StorageTestUtils::get_dummy_file_wrapper_config_inline()); + + dataset_name = "non_existent_dataset"; + ASSERT_NO_THROW(result = StorageServiceImpl::get_dataset_data(session, dataset_name)); + ASSERT_EQ(result.dataset_id, -1); + ASSERT_EQ(result.base_path, ""); + ASSERT_EQ(result.filesystem_wrapper_type, FilesystemWrapperType::INVALID_FSW); + ASSERT_EQ(result.file_wrapper_type, FileWrapperType::INVALID_FW); +} \ No newline at end of file diff --git 
a/modyn/tests/storage/internal/grpc/test_grpc_server.py b/modyn/tests/storage/internal/grpc/test_grpc_server.py deleted file mode 100644 index 3eb992702..000000000 --- a/modyn/tests/storage/internal/grpc/test_grpc_server.py +++ /dev/null @@ -1,12 +0,0 @@ -# pylint: disable=unused-argument - -from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer - - -def get_modyn_config(): - return {"storage": {"port": "50051", "type": "grpc", "sample_batch_size": 1024}} - - -def test_init(): - grpc_server = StorageGRPCServer(get_modyn_config()) - assert grpc_server.modyn_config == get_modyn_config() diff --git a/modyn/tests/storage/internal/grpc/test_storage_grpc_servicer.py b/modyn/tests/storage/internal/grpc/test_storage_grpc_servicer.py deleted file mode 100644 index e8c13eb3e..000000000 --- a/modyn/tests/storage/internal/grpc/test_storage_grpc_servicer.py +++ /dev/null @@ -1,469 +0,0 @@ -# pylint: disable=unused-argument, no-name-in-module -import json -import os -import pathlib -from unittest.mock import patch - -import pytest -from modyn.storage.internal.database.models import Dataset, File, Sample -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.file_wrapper.single_sample_file_wrapper import SingleSampleFileWrapper -from modyn.storage.internal.filesystem_wrapper.local_filesystem_wrapper import LocalFilesystemWrapper -from modyn.storage.internal.grpc.generated.storage_pb2 import ( - DatasetAvailableRequest, - DeleteDataRequest, - GetDataInIntervalRequest, - GetDataPerWorkerRequest, - GetDataPerWorkerResponse, - GetDatasetSizeRequest, - GetDatasetSizeResponse, - GetNewDataSinceRequest, - GetRequest, - RegisterNewDatasetRequest, -) -from modyn.storage.internal.grpc.storage_grpc_servicer import StorageGRPCServicer -from modyn.utils import current_time_millis - -TMP_FILE = str(pathlib.Path(os.path.abspath(__file__)).parent / "test.png") -TMP_FILE2 = str(pathlib.Path(os.path.abspath(__file__)).parent / "test2.png") -TMP_FILE3 = str(pathlib.Path(os.path.abspath(__file__)).parent / "test3.png") -DATABASE = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.database" -NOW = current_time_millis() - - -def get_minimal_modyn_config() -> dict: - return { - "storage": { - "filesystem": {"type": "LocalFilesystemWrapper", "base_path": os.path.dirname(TMP_FILE)}, - "sample_batch_size": 1024, - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": "0", - "database": f"{DATABASE}", - }, - "new_file_watcher": {"interval": 1}, - "datasets": [ - { - "name": "test", - "base_path": os.path.dirname(TMP_FILE), - "filesystem_wrapper_type": LocalFilesystemWrapper, - "file_wrapper_type": SingleSampleFileWrapper, - "description": "test", - "version": "0.0.1", - "file_wrapper_config": {}, - } - ], - }, - "project": {"name": "test", "version": "0.0.1"}, - "input": {"type": "LOCAL", "path": os.path.dirname(TMP_FILE)}, - "odm": {"type": "LOCAL"}, - } - - -def setup(): - if os.path.exists(DATABASE): - os.remove(DATABASE) - - os.makedirs(os.path.dirname(TMP_FILE), exist_ok=True) - with open(TMP_FILE, "wb") as file: - file.write(b"test") - with open(TMP_FILE2, "wb") as file: - file.write(b"test2") - with open(TMP_FILE3, "wb") as file: - file.write(b"test3") - - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - now = NOW - before_now = now - 1 - - database.create_tables() - - session = database.session - - dataset = Dataset( - name="test", - 
base_path=os.path.dirname(TMP_FILE), - filesystem_wrapper_type="LocalFilesystemWrapper", - file_wrapper_type="SingleSampleFileWrapper", - description="test", - version="0.0.1", - file_wrapper_config=json.dumps({"file_extension": "png"}), - last_timestamp=now, - ) - - session.add(dataset) - - session.commit() - - file = File(path=TMP_FILE, dataset=dataset, created_at=now, updated_at=now, number_of_samples=2) - - session.add(file) - - file2 = File(path=TMP_FILE2, dataset=dataset, created_at=now, updated_at=now, number_of_samples=2) - - session.add(file2) - - file3 = File(path=TMP_FILE3, dataset=dataset, created_at=before_now, updated_at=before_now, number_of_samples=2) - - session.add(file3) - - session.commit() - - sample = Sample(dataset_id=dataset.dataset_id, file_id=file.file_id, index=0, label=1) - - session.add(sample) - - sample3 = Sample(dataset_id=dataset.dataset_id, file_id=file2.file_id, index=0, label=3) - - session.add(sample3) - - sample5 = Sample(dataset_id=dataset.dataset_id, file_id=file3.file_id, index=0, label=5) - - session.add(sample5) - - session.commit() - - assert ( - sample.sample_id == 1 and sample3.sample_id == 2 and sample5.sample_id == 3 - ), "Inherent assumptions of primary key generation not met" - - -def teardown(): - os.remove(DATABASE) - try: - os.remove(TMP_FILE) - except FileNotFoundError: - pass - try: - os.remove(TMP_FILE2) - except FileNotFoundError: - pass - try: - os.remove(TMP_FILE3) - except FileNotFoundError: - pass - - -def test_init() -> None: - server = StorageGRPCServicer(get_minimal_modyn_config()) - assert server is not None - - -def test_get(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetRequest(dataset_id="test", keys=[0, 1, 2]) - - expected_responses = [([b"test"], [1], [1]), ([b"test2"], [2], [3]), ([b"test3"], [3], [5])] - - for response, expected_response in zip(server.Get(request, None), expected_responses): - assert response is not None - assert response.samples == expected_response[0] - assert response.keys == expected_response[1] - assert response.labels == expected_response[2] - - -def test_get_invalid_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetRequest(dataset_id="test2", keys=[1, 2, 3]) - - for response in server.Get(request, None): - assert response is not None - assert response.samples == [] - assert response.keys == [] - assert response.labels == [] - - -def test_get_invalid_key(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetRequest(dataset_id="test", keys=[42]) - responses = list(server.Get(request, None)) - assert len(responses) == 1 - response = responses[0] - - assert response is not None - assert response.samples == [] - assert response.keys == [] - assert response.labels == [] - - -def test_get_not_all_keys_found(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetRequest(dataset_id="test", keys=[1, 42]) - - for response in server.Get(request, None): - assert response is not None - assert response.samples == [b"test"] - - -def test_get_no_keys_providesd(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetRequest(dataset_id="test", keys=[]) - - for response in server.Get(request, None): - assert response is not None - assert response.samples == [] - - -def test_get_new_data_since(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetNewDataSinceRequest(dataset_id="test", timestamp=0) - - responses = list(server.GetNewDataSince(request, 
None)) - assert 1 == len(responses) - response = responses[0] - - assert response is not None - assert response.keys == [3, 1, 2] - assert response.timestamps == [NOW - 1, NOW, NOW] - assert response.labels == [5, 1, 3] - - -def test_get_new_data_since_batched(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - server._sample_batch_size = 1 - - request = GetNewDataSinceRequest(dataset_id="test", timestamp=0) - - responses = list(server.GetNewDataSince(request, None)) - - assert 3 == len(responses) - response1 = responses[0] - response2 = responses[1] - response3 = responses[2] - - assert response1 is not None - assert response1.keys == [3] - assert response1.timestamps == [NOW - 1] - assert response1.labels == [5] - - assert response2 is not None - assert response2.keys == [1] - assert response2.timestamps == [NOW] - assert response2.labels == [1] - - assert response3 is not None - assert response3.keys == [2] - assert response3.timestamps == [NOW] - assert response3.labels == [3] - - -def test_get_new_data_since_invalid_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetNewDataSinceRequest(dataset_id="test3", timestamp=0) - - responses = list(server.GetNewDataSince(request, None)) - assert len(responses) == 1 - response = responses[0] - assert response is not None - assert response.keys == [] - assert response.timestamps == [] - assert response.labels == [] - - -def test_get_new_data_since_no_new_data(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetNewDataSinceRequest(dataset_id="test", timestamp=NOW + 100000) - - responses = list(server.GetNewDataSince(request, None)) - assert len(responses) == 0 - - -def test_get_data_in_interval(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetDataInIntervalRequest(dataset_id="test", start_timestamp=0, end_timestamp=NOW + 100000) - - responses = list(server.GetDataInInterval(request, None)) - - assert len(responses) == 1 - response = responses[0] - - assert response is not None - assert response.keys == [3, 1, 2] - assert response.timestamps == [NOW - 1, NOW, NOW] - assert response.labels == [5, 1, 3] - - request = GetDataInIntervalRequest(dataset_id="test", start_timestamp=0, end_timestamp=NOW - 1) - - responses = list(server.GetDataInInterval(request, None)) - - assert len(responses) == 1 - response = responses[0] - - assert response is not None - assert response.keys == [3] - assert response.timestamps == [NOW - 1] - assert response.labels == [5] - - request = GetDataInIntervalRequest(dataset_id="test", start_timestamp=0, end_timestamp=10) - - responses = list(server.GetDataInInterval(request, None)) - - assert len(responses) == 0 - - -def test_get_data_in_interval_invalid_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetDataInIntervalRequest(dataset_id="test2", start_timestamp=0, end_timestamp=NOW + 100000) - - responses = list(server.GetDataInInterval(request, None)) - assert len(responses) == 1 - response = responses[0] - assert response is not None - assert response.keys == [] - assert response.timestamps == [] - assert response.labels == [] - - -def test_get_data_per_worker(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = GetDataPerWorkerRequest(dataset_id="test", worker_id=0, total_workers=2) - response: [GetDataPerWorkerResponse] = list(server.GetDataPerWorker(request, None)) - assert len(response) == 1 - assert response[0].keys == [1, 2] - - request = 
GetDataPerWorkerRequest(dataset_id="test", worker_id=1, total_workers=2) - response = list(server.GetDataPerWorker(request, None)) - assert len(response) == 1 - assert response[0].keys == [3] - - request = GetDataPerWorkerRequest(dataset_id="test", worker_id=3, total_workers=4) - response: [GetDataPerWorkerResponse] = list(server.GetDataPerWorker(request, None)) - assert len(response) == 0 - - request = GetDataPerWorkerRequest(dataset_id="test", worker_id=0, total_workers=1) - response: [GetDataPerWorkerResponse] = list(server.GetDataPerWorker(request, None)) - assert len(response) == 1 - assert response[0].keys == [1, 2, 3] - - request = GetDataPerWorkerRequest(dataset_id="test", worker_id=2, total_workers=2) - with pytest.raises(ValueError): - list(server.GetDataPerWorker(request, None)) - - -def test_get_dataset_size(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - request = GetDatasetSizeRequest(dataset_id="test") - response: GetDatasetSizeResponse = server.GetDatasetSize(request, None) - - assert response is not None - assert response.success - assert response.num_keys == 3 - - -def test_get_dataset_size_invalid(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - request = GetDatasetSizeRequest(dataset_id="unknown") - response: GetDatasetSizeResponse = server.GetDatasetSize(request, None) - - assert response is not None - assert not response.success - - -def test_check_availability(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = DatasetAvailableRequest(dataset_id="test") - - response = server.CheckAvailability(request, None) - assert response is not None - assert response.available - - -def test_check_availability_invalid_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = DatasetAvailableRequest(dataset_id="test2") - - response = server.CheckAvailability(request, None) - assert response is not None - assert not response.available - - -def test_register_new_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = RegisterNewDatasetRequest( - dataset_id="test3", - base_path=os.path.dirname(TMP_FILE), - filesystem_wrapper_type="LocalFilesystemWrapper", - file_wrapper_type="SingleSampleFileWrapper", - description="test", - version="0.0.1", - file_wrapper_config="{}", - ) - - response = server.RegisterNewDataset(request, None) - assert response is not None - assert response.success - - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - session = database.session - - dataset = session.query(Dataset).filter(Dataset.name == "test3").first() - - assert dataset is not None - assert dataset.name == "test3" - assert dataset.base_path == os.path.dirname(TMP_FILE) - assert dataset.description == "test" - assert dataset.version == "0.0.1" - - -@patch("modyn.storage.internal.grpc.storage_grpc_servicer.current_time_millis", return_value=NOW) -def test_get_current_timestamp(mock_current_time_millis): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - response = server.GetCurrentTimestamp(None, None) - assert response is not None - assert response.timestamp == NOW - - -def test_delete_data(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = DeleteDataRequest(dataset_id="test", keys=[1, 2]) - - response = server.DeleteData(request, None) - assert response is not None - assert response.success - - assert not os.path.exists(TMP_FILE) - assert not os.path.exists(TMP_FILE2) - assert os.path.exists(TMP_FILE3) - - with 
StorageDatabaseConnection(get_minimal_modyn_config()) as database: - session = database.session - - files = session.query(File).filter(File.dataset_id == "test").all() - - assert len(files) == 0 - - -def test_delete_dataset(): - server = StorageGRPCServicer(get_minimal_modyn_config()) - - request = DatasetAvailableRequest(dataset_id="test") - - response = server.DeleteDataset(request, None) - assert response is not None - assert response.success - - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - session = database.session - - dataset = session.query(Dataset).filter(Dataset.name == "test").first() - - assert dataset is None diff --git a/modyn/tests/storage/storage_test_utils.cpp b/modyn/tests/storage/storage_test_utils.cpp new file mode 100644 index 000000000..78981bf83 --- /dev/null +++ b/modyn/tests/storage/storage_test_utils.cpp @@ -0,0 +1,32 @@ +#include "storage_test_utils.hpp" + +using namespace modyn::storage; + +YAML::Node StorageTestUtils::get_dummy_file_wrapper_config() { + YAML::Node config; + config["file_extension"] = ".txt"; + config["label_file_extension"] = ".json"; + config["label_size"] = 2; + config["record_size"] = 4; + config["label_index"] = 0; + config["encoding"] = "utf-8"; + config["validate_file_content"] = false; + config["ignore_first_line"] = true; + config["separator"] = ','; + return config; +} + +std::string StorageTestUtils::get_dummy_file_wrapper_config_inline() { + std::string test_config = R"( +file_extension: ".txt" +label_file_extension: ".lbl" +label_size: 1 +record_size: 2 +label_index: 0 +encoding: "utf-8" +validate_file_content: false +ignore_first_line: false +separator: ',' +)"; + return test_config; +} diff --git a/modyn/tests/storage/storage_test_utils.hpp b/modyn/tests/storage/storage_test_utils.hpp new file mode 100644 index 000000000..1dd6bfc04 --- /dev/null +++ b/modyn/tests/storage/storage_test_utils.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +#include + +namespace modyn::storage { + +class StorageTestUtils { + public: + static YAML::Node get_dummy_file_wrapper_config(); + static std::string get_dummy_file_wrapper_config_inline(); +}; + +template +class MockServerWriter : public grpc::ServerWriterInterface { + public: + MockServerWriter() = default; + + MockServerWriter(grpc::internal::Call* call, grpc::ServerContext* ctx) : call_(call), ctx_(ctx) {} + + MOCK_METHOD0_T(SendInitialMetadata, void()); + + bool Write(const T& response, // NOLINT(readability-identifier-naming) + const grpc::WriteOptions /* options */) override { + responses_.push_back(response); + return true; + }; + + // NOLINTNEXTLINE(readability-identifier-naming) + inline bool Write(const T& msg) { return Write(msg, grpc::WriteOptions()); } + + std::vector get_responses() { return responses_; } + + private: + grpc::internal::Call* const call_ = nullptr; + grpc::ServerContext* const ctx_ = nullptr; + template + friend class grpc::internal::ServerStreamingHandler; + + std::vector responses_; +}; + +} // namespace modyn::storage diff --git a/modyn/tests/storage/test_storage.py b/modyn/tests/storage/test_storage.py deleted file mode 100644 index 22bd0b74d..000000000 --- a/modyn/tests/storage/test_storage.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import pathlib -from unittest.mock import patch - -import pytest -from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer -from modyn.storage.storage import Storage - 
-database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db" -modyn_config = ( - pathlib.Path(os.path.abspath(__file__)).parent.parent.parent / "config" / "examples" / "modyn_config.yaml" -) - - -def get_minimal_modyn_config() -> dict: - return { - "storage": { - "port": "50051", - "hostname": "localhost", - "sample_batch_size": 1024, - "insertion_threads": 8, - "filesystem": {"type": "LocalFilesystemWrapper", "base_path": "/tmp/modyn"}, - "database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": "0", - "database": f"{database_path}", - }, - "new_file_watcher": {"interval": 1}, - "datasets": [ - { - "name": "test", - "base_path": "/tmp/modyn", - "filesystem_wrapper_type": "LocalFilesystemWrapper", - "file_wrapper_type": "SingleSampleFileWrapper", - "description": "test", - "version": "0.0.1", - "file_wrapper_config": {}, - } - ], - }, - "project": {"name": "test", "version": "0.0.1"}, - "input": {"type": "LOCAL", "path": "/tmp/modyn"}, - "metadata_database": { - "drivername": "sqlite", - "username": "", - "password": "", - "host": "", - "port": "0", - "database": f"{database_path}", - }, - "selector": {"hostname": "host", "port": "1337"}, - "trainer_server": {"hostname": "host", "port": "1337"}, - "evaluator": {"hostname": "host", "port": "1337"}, - "model_storage": {"hostname": "host", "port": "1337", "ftp_port": "1337", "models_directory": "test.dir"}, - } - - -def teardown(): - os.remove(database_path) - - -def setup(): - if database_path.exists(): - os.remove(database_path) - os.makedirs(database_path.parent, exist_ok=True) - - -def get_invalid_modyn_config() -> dict: - return {"invalid": "invalid"} - - -class MockGRPCInstance: - def wait_for_termination(self, *args, **kwargs): # pylint: disable=unused-argument - return - - -class MockGRPCServer(StorageGRPCServer): - def __enter__(self): - return MockGRPCInstance() - - def __exit__(self, *args, **kwargs): # pylint: disable=unused-argument - pass - - -def test_storage_init(): - storage = Storage(modyn_config) - assert storage.modyn_config == modyn_config - - -def test_validate_config(): - storage = Storage(modyn_config) - assert storage._validate_config()[0] - - -@patch("modyn.storage.storage.StorageGRPCServer", MockGRPCServer) -def test_run(): - with StorageDatabaseConnection(get_minimal_modyn_config()) as database: - database.create_tables() - storage = Storage(get_minimal_modyn_config()) - storage.run() - - -def test_invalid_config(): - with pytest.raises(ValueError): - Storage(get_invalid_modyn_config()) diff --git a/modyn/tests/storage/test_storage_entrypoint.py b/modyn/tests/storage/test_storage_entrypoint.py deleted file mode 100644 index c2407016d..000000000 --- a/modyn/tests/storage/test_storage_entrypoint.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -This tests that the entry point script for the storage -successfully runs through. This is _not_ the place for an integration test. 
-""" -import os -import pathlib -from unittest.mock import patch - -from modyn.storage import Storage - -SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) - -EXAMPLE_SYSTEM_CONFIG = SCRIPT_PATH.parent.parent.parent / "config" / "examples" / "modyn_config.yaml" - -NO_FILE = SCRIPT_PATH.parent / "thisshouldnot.exist" - - -def noop_constructor_mock(self, modyn_config: dict) -> None: # pylint: disable=unused-argument - pass - - -def noop_run(self) -> None: # pylint: disable=unused-argument - pass - - -@patch.object(Storage, "__init__", noop_constructor_mock) -@patch.object(Storage, "run", noop_run) -def test_storage_script_runs(script_runner): - ret = script_runner.run("_modyn_storage", str(EXAMPLE_SYSTEM_CONFIG)) - assert ret.success - - -@patch.object(Storage, "__init__", noop_constructor_mock) -def test_storage_script_fails_on_non_existing_system_config(script_runner): - assert not NO_FILE.is_file(), "File that shouldn't exist exists." - ret = script_runner.run("_modyn_storage", str(NO_FILE)) - assert not ret.success diff --git a/modyn/tests/supervisor/internal/triggers/test_timetrigger.py b/modyn/tests/supervisor/internal/triggers/test_timetrigger.py index b23e0f3fe..306759df2 100644 --- a/modyn/tests/supervisor/internal/triggers/test_timetrigger.py +++ b/modyn/tests/supervisor/internal/triggers/test_timetrigger.py @@ -4,7 +4,7 @@ def test_initialization() -> None: trigger = TimeTrigger({"trigger_every": "2s"}) - assert trigger.trigger_every_ms == 2000 + assert trigger.trigger_every_s == 2 assert trigger.next_trigger_at is None @@ -17,7 +17,7 @@ def test_init_fails_if_invalid() -> None: def test_inform() -> None: - trigger = TimeTrigger({"trigger_every": "1s"}) + trigger = TimeTrigger({"trigger_every": "1000s"}) LABEL = 2 # pylint: disable=invalid-name # pylint: disable-next=use-implicit-booleaness-not-comparison assert trigger.inform([]) == [] diff --git a/modyn/tests/utils/test_utils.cpp b/modyn/tests/utils/test_utils.cpp index e69de29bb..bda27ee94 100644 --- a/modyn/tests/utils/test_utils.cpp +++ b/modyn/tests/utils/test_utils.cpp @@ -0,0 +1,34 @@ +#include "test_utils.hpp" + +using namespace modyn::test; + +void TestUtils::create_dummy_yaml() { + std::ofstream out("config.yaml"); + out << "storage:" << '\n'; + out << " port: 50042" << '\n'; + out << " sample_batch_size: 5" << '\n'; + out << " sample_dbinsertion_batchsize: 10" << '\n'; + out << " insertion_threads: 1" << '\n'; + out << " retrieval_threads: 1" << '\n'; + out << " database:" << '\n'; + out << " drivername: sqlite3" << '\n'; + out << " database: test.db" << '\n'; + out << " username: ''" << '\n'; + out << " password: ''" << '\n'; + out << " host: ''" << '\n'; + out << " port: ''" << '\n'; + out.close(); +} + +void TestUtils::delete_dummy_yaml() { (void)std::remove("config.yaml"); } + +YAML::Node TestUtils::get_dummy_config() { + YAML::Node config; + config["storage"]["database"]["drivername"] = "sqlite3"; + config["storage"]["database"]["database"] = "test.db"; + config["storage"]["database"]["username"] = ""; + config["storage"]["database"]["password"] = ""; + config["storage"]["database"]["host"] = ""; + config["storage"]["database"]["port"] = ""; + return config; +} \ No newline at end of file diff --git a/modyn/tests/utils/test_utils.hpp b/modyn/tests/utils/test_utils.hpp index e69de29bb..0f0f9bb8f 100644 --- a/modyn/tests/utils/test_utils.hpp +++ b/modyn/tests/utils/test_utils.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include + +#include + +namespace modyn::test { +class TestUtils { + public: + static void 
create_dummy_yaml(); + static void delete_dummy_yaml(); + static YAML::Node get_dummy_config(); +}; +} // namespace modyn::test diff --git a/plotting/system/avg_max_med_batch.py b/plotting/system/avg_max_med_batch.py deleted file mode 100644 index c05b1fb2f..000000000 --- a/plotting/system/avg_max_med_batch.py +++ /dev/null @@ -1,90 +0,0 @@ -import glob -import sys - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -from plotting.common.common import * - - -def plot_baravg(pipeline_log, ax, trigger): - data = [] - - bar_labels = dict() - - for pipeline in pipeline_log: - - relevant_data = pipeline["supervisor"]["triggers"][trigger]["trainer_log"]["epochs"][0] - meta_data = pipeline["configuration"]["pipeline_config"]["training"] - - max_fb = relevant_data["MaxFetchBatch"] / 1000 - avg_fb = relevant_data["AvgFetchBatch"] / 1000 - - total_fb = relevant_data["TotalFetchBatch"] / 1000 - total_train = pipeline["supervisor"]["triggers"][trigger]["trainer_log"]["total_train"] / 1000 - - x = f"{meta_data['dataloader_workers']}/{meta_data['num_prefetched_partitions']}/{meta_data['parallel_prefetch_requests']}" - - percentage = round((total_fb / total_train) * 100,1) - bar_labels[x] = f"{int(total_fb)} ({percentage}%)\n" - - data.append([x, avg_fb, max_fb]) - - data_df = pd.DataFrame(data, columns=["x", "Avg", "Max"]) - test_data_melted = data_df.melt(id_vars="x", value_name = "time", var_name="measure") - - mask = test_data_melted.measure.isin(['Max']) - scale = test_data_melted[~mask].time.mean()/ test_data_melted[mask].time.mean() - test_data_melted.loc[mask, 'time'] = test_data_melted.loc[mask, 'time']*scale - - sns.barplot(data=test_data_melted, x="x", y="time", hue="measure", ax=ax) - bar_label_list = [bar_labels[x._text] for x in ax.get_xticklabels()] - ax.bar_label(ax.containers[0], labels=bar_label_list, size=11) - - ax.set_xlabel("Workers / Prefetched Partitions / Parallel Requests") - ax.tick_params(axis='x', which='major', labelsize=14) - ax.set_ylabel("Avg") - ax2 = ax.twinx() - - ax2.set_ylim(ax.get_ylim()) - ax2.set_yticklabels(np.round(ax.get_yticks()/scale,1)) - ax2.set_ylabel('Max') - ax.get_legend().set_visible(False) - - #ax.set_xticks(list(x)) - #ax.set_xticklabels([f"{idx + 1}" for idx, _ in enumerate(x)]) - #ax.set_xlabel("Waiting time for next batch (seconds)") - - #ax.set_ylabel("Count") - - ax.set_title("Average and Max Time per Batch") - -def load_all_pipelines(data_path): - all_data = [] - - for filename in glob.iglob(data_path + '/**/*.log', recursive=True): - data = LOAD_DATA(filename) - all_data.append(data) - - return all_data - -if __name__ == '__main__': - # Idee: Selber plot mit TotalTrain und anteil fetch batch an total train - - data_path, plot_dir = INIT(sys.argv) - data = load_all_pipelines(data_path) - fig, ax = plt.subplots(1,1, figsize=DOUBLE_FIG_SIZE) - - plot_baravg(data, ax, "0") - - - HATCH_WIDTH() - FIG_LEGEND(fig) - - Y_GRID(ax) - HIDE_BORDERS(ax) - - plot_path = os.path.join(plot_dir, "avg_max") - SAVE_PLOT(plot_path) - PRINT_PLOT_PATHS() \ No newline at end of file diff --git a/scripts/clang-tidy.sh b/scripts/clang-tidy.sh index dddd999c3..3172d2fc5 100755 --- a/scripts/clang-tidy.sh +++ b/scripts/clang-tidy.sh @@ -13,23 +13,24 @@ function run_build() { echo "Running cmake build..." set -x - # TODO(MaxiBoether): add this when merging into storage PR - #mkdir -p "${BUILD_DIR}" - #cmake -S ${SCRIPT_DIR}/.. 
-B "${BUILD_DIR}" \ - # -DCMAKE_BUILD_TYPE=Debug \ - # -DCMAKE_UNITY_BUILD=OFF + mkdir -p "${BUILD_DIR}" + cmake -S ${SCRIPT_DIR}/.. -B "${BUILD_DIR}" \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_UNITY_BUILD=OFF \ + -DMODYN_BUILD_STORAGE=ON - #pushd ${BUILD_DIR} - #make -j8 modynstorage-proto - #popd + pushd ${BUILD_DIR} + make -j8 modyn-storage-proto + popd cmake -S ${SCRIPT_DIR}/.. -B "${BUILD_DIR}" \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_UNITY_BUILD=ON \ - -DCMAKE_UNITY_BUILD_BATCH_SIZE=0 + -DCMAKE_UNITY_BUILD_BATCH_SIZE=0 \ + -DMODYN_BUILD_STORAGE=ON # Due to the include-based nature of the unity build, clang-tidy will not find this configuration file otherwise: - # ln -fs "${SCRIPT_DIR}"/../modyn/tests/.clang-tidy "${BUILD_DIR}"/modyn/tests/ + ln -fs "${SCRIPT_DIR}"/../modyn/tests/.clang-tidy "${BUILD_DIR}"/modyn/tests/ set +x } @@ -46,12 +47,13 @@ function run_tidy() { echo "Will also automatically fix everything that we can..." fi + # For storage, we explicitly include src and include to avoid matching files in the generated directory, containing auto-generated gRPC headers ${RUN_CLANG_TIDY} -p "${BUILD_DIR}" \ -clang-tidy-binary="${CLANG_TIDY}" \ -config-file="${SCRIPT_DIR}/../.clang-tidy" \ -quiet \ -checks='-bugprone-suspicious-include,-google-global-names-in-headers' \ - -header-filter='(.*modyn/storage/.*)|(.*modyn/common/.*)|(.*modyn/playground/.*)|(.*modyn/selector/.*)|(.*modyn/tests.*)' \ + -header-filter='(.*modyn/storage/src/.*)|(.*modyn/storage/include/.*)|(.*modyn/common/.*)|(.*modyn/playground/.*)|(.*modyn/selector/.*)|(.*modyn/tests.*)' \ ${additional_args} \ "${BUILD_DIR}"/modyn/*/Unity/*.cxx \ "${BUILD_DIR}"/modyn/*/*/Unity/*.cxx \ diff --git a/scripts/run_integrationtests.sh b/scripts/run_integrationtests.sh index ac702575a..f850fc320 100755 --- a/scripts/run_integrationtests.sh +++ b/scripts/run_integrationtests.sh @@ -8,6 +8,14 @@ docker compose down BUILDTYPE=${1:-Release} echo "Using build type ${BUILDTYPE} for integrationtests." +if [[ "$BUILDTYPE" == "Release" ]]; then + DEPBUILDTYPE="Release" +else + # Since Asan/Tsan are not necessarily targets of dependencies, we switch to debug mode in all other cases. + DEPBUILDTYPE="Debug" +fi + +echo "Inferred dependency buildtype ${DEPBUILDTYPE}." # When on Github CI, we use the default postgres config to not go OOM if [[ ! -z "$CI" ]]; then @@ -17,12 +25,26 @@ if [[ ! -z "$CI" ]]; then cp conf/default_postgresql.conf conf/storage_postgresql.conf fi -docker build -t modyndependencies -f docker/Dependencies/Dockerfile . -docker build -t modynbase -f docker/Base/Dockerfile --build-arg MODYN_BUILDTYPE=$BUILDTYPE . +docker build -t modyndependencies -f docker/Dependencies/Dockerfile --build-arg MODYN_BUILDTYPE=$BUILDTYPE --build-arg MODYN_DEP_BUILDTYPE=$DEPBUILDTYPE . +docker build -t modynbase -f docker/Base/Dockerfile . # APEX docker build -t modynapex -f docker/Apex/Dockerfile . docker compose up --build tests --abort-on-container-exit --exit-code-from tests + exitcode=$? +echo "LOGS START" +echo "METADATADB" +docker logs $(docker compose ps -q metadata-db) +echo "STORAGEDB" +docker logs $(docker compose ps -q storage-db) +echo "STORAGE" +docker logs $(docker compose ps -q storage) +echo "SELECTOR" +docker logs $(docker compose ps -q selector) +echo "TRAINERSERVER" +docker logs $(docker compose ps -q trainer_server) +echo "LOGS END" + # Cleanup docker compose down if [[ ! 
-z "$CI" ]]; then diff --git a/scripts/run_modyn.sh b/scripts/run_modyn.sh index 69dc83255..3f85c1645 100755 --- a/scripts/run_modyn.sh +++ b/scripts/run_modyn.sh @@ -5,7 +5,20 @@ PARENT_DIR=$(realpath ${DIR}/../) pushd $PARENT_DIR docker compose down -docker build -t modyndependencies -f docker/Dependencies/Dockerfile . + +BUILDTYPE=${1:-Release} +echo "Running Modyn with buildtype ${BUILDTYPE}." + +if [[ "$BUILDTYPE" == "Release" ]]; then + DEPBUILDTYPE="Release" +else + # Since Asan/Tsan are not necessarily targets of dependencies, we switch to debug mode in all other cases. + DEPBUILDTYPE="Debug" +fi + +echo "Inferred dependency buildtype ${DEPBUILDTYPE}." + +docker build -t modyndependencies -f docker/Dependencies/Dockerfile --build-arg MODYN_BUILDTYPE=$BUILDTYPE --build-arg MODYN_DEP_BUILDTYPE=$DEPBUILDTYPE . docker build -t modynbase -f docker/Base/Dockerfile . # APEX docker build -t modynapex -f docker/Apex/Dockerfile . docker compose up -d --build supervisor