From 63d5447d08549b163dc97a9504cba70893ec8af1 Mon Sep 17 00:00:00 2001 From: Five Grant <5@fivegrant.com> Date: Wed, 14 Feb 2024 11:53:16 -0600 Subject: [PATCH] Separate operation models into separate files (#57) --- .gitignore | 4 +- pyproject.toml | 5 +- service/api.py | 2 +- service/models.py | 350 ------------------ service/models/__init__.py | 4 + service/models/base.py | 66 ++++ .../convert.py => models/converters.py} | 0 service/models/operations/__init__.py | 3 + service/models/operations/calibrate.py | 83 +++++ .../models/operations/ensemble_simulate.py | 53 +++ service/models/operations/simulate.py | 65 ++++ service/models/response.py | 43 +++ 12 files changed, 325 insertions(+), 353 deletions(-) delete mode 100644 service/models.py create mode 100644 service/models/__init__.py create mode 100644 service/models/base.py rename service/{utils/convert.py => models/converters.py} (100%) create mode 100644 service/models/operations/__init__.py create mode 100644 service/models/operations/calibrate.py create mode 100644 service/models/operations/ensemble_simulate.py create mode 100644 service/models/operations/simulate.py create mode 100644 service/models/response.py diff --git a/.gitignore b/.gitignore index cadc913..52d2c09 100644 --- a/.gitignore +++ b/.gitignore @@ -70,8 +70,10 @@ dist build # MISC +.ruff_cache create.sql .eslintcache .version -*.nix +venv +flake.lock diff --git a/pyproject.toml b/pyproject.toml index 0853221..561326f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyciemss-service" -version = "1.7.0" +version = "2.0.0" description = "PyCIEMSS simulation service to run CIEMSS simulations" authors = ["Powell Fendley", "Five Grant"] readme = "README.md" @@ -54,3 +54,6 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] ignore = ["E501"] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401", "F403"] \ No newline at end of file diff --git a/service/api.py b/service/api.py index 99881ae..a1667db 100644 --- a/service/api.py +++ b/service/api.py @@ -5,7 +5,7 @@ from fastapi import FastAPI, Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware -from models import ( +from service.models import ( Status, JobResponse, Calibrate, diff --git a/service/models.py b/service/models.py deleted file mode 100644 index 5881394..0000000 --- a/service/models.py +++ /dev/null @@ -1,350 +0,0 @@ -from __future__ import annotations -import socket # noqa: F401 -import logging # noqa: F401 - -from enum import Enum -from typing import ClassVar, Dict, List, Optional -from pydantic import BaseModel, Field, Extra - -from pika.exceptions import AMQPConnectionError - -# TODO: Do not use Torch in PyCIEMSS Library interface -import torch - - -from utils.convert import convert_to_static_interventions, convert_to_solution_mapping -from utils.rabbitmq import gen_rabbitmq_hook # noqa: F401 -from utils.tds import fetch_dataset, fetch_model, fetch_inferred_parameters - - -class Timespan(BaseModel): - start: float = Field(..., example=0) - end: float = Field(..., example=90) - - -class Status(Enum): - cancelled = "cancelled" - complete = "complete" - error = "error" - queued = "queued" - running = "running" - failed = "failed" - started = "started" - finished = "finished" - - @staticmethod - def from_rq(rq_status): - rq_status_to_tds_status = { - "canceled": "cancelled", - "complete": "complete", - "error": "error", - "queued": "queued", - "running": "running", - "failed": "failed", - "started": "running", - "finished": "complete", - } - return Status(rq_status_to_tds_status[rq_status]) - - -class ModelConfig(BaseModel): - id: str = Field(..., example="cd339570-047d-11ee-be55") - solution_mappings: dict[str, str] = Field( - ..., - example={"Infected": "Cases", "Hospitalizations": "hospitalized_population"}, - ) - weight: float = Field(..., example="cd339570-047d-11ee-be55") - - -class Dataset(BaseModel): - id: str = Field(None, example="cd339570-047d-11ee-be55") - filename: str = Field(None, example="dataset.csv") - mappings: Dict[str, str] = Field( - default_factory=dict, - description=( - "Mappings from the dataset column names to " - "the model names they should be replaced with." - ), - example={"postive_tests": "infected"}, - ) - - -class InterventionObject(BaseModel): - timestep: float - name: str - value: float - - -class InterventionSelection(BaseModel): - timestep: float - name: str - - -class QuantityOfInterest(BaseModel): - function: str - state: str - arg: int # TODO: Make this a list of args? - - -######################### Base operation request ############ -class OperationRequest(BaseModel): - pyciemss_lib_function: ClassVar[str] = "" - engine: str = Field("ciemss", example="ciemss") - user_id: str = Field("not_provided", example="not_provided") - - def gen_pyciemss_args(self, job_id): - raise NotImplementedError("PyCIEMSS cannot handle this operation") - - def run_sciml_operation(self, job_id, julia_context): - raise NotImplementedError("SciML cannot handle this operation") - - # @field_validator("engine") - # def must_be_ciemss(cls, engine_choice): - # if engine_choice != "ciemss": - # raise ValueError("The chosen engine is NOT 'ciemss'") - # return engine_choice - - -######################### `simulate` Operation ############ -class SimulateExtra(BaseModel): - num_samples: int = Field( - 100, description="number of samples for a CIEMSS simulation", example=100 - ) - inferred_parameters: Optional[str] = Field( - None, - description="id from a previous calibration", - example=None, - ) - - -class Simulate(OperationRequest): - pyciemss_lib_function: ClassVar[str] = "sample" - model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") - timespan: Timespan = Timespan(start=0, end=90) - interventions: List[InterventionObject] = Field( - default_factory=list, example=[{"timestep": 1, "name": "beta", "value": 0.4}] - ) - step_size: float = 1.0 - extra: SimulateExtra = Field( - None, - description="optional extra system specific arguments for advanced use cases", - ) - - def gen_pyciemss_args(self, job_id): - # Get model from TDS - amr_path = fetch_model(self.model_config_id, job_id) - - interventions = convert_to_static_interventions(self.interventions) - - extra_options = self.extra.dict() - inferred_parameters = fetch_inferred_parameters( - extra_options.pop("inferred_parameters"), job_id - ) - - return { - "model_path_or_json": amr_path, - "logging_step_size": self.step_size, - "start_time": self.timespan.start, - "end_time": self.timespan.end, - "static_parameter_interventions": interventions, - "inferred_parameters": inferred_parameters, - **extra_options, - } - - def run_sciml_operation(self, job_id, julia_context): - amr_path = fetch_model(self.model_config_id, job_id) - with open(amr_path, "r") as file: - amr = file.read() - result = julia_context.simulate(amr, self.timespan.start, self.timespan.end) - return {"data": julia_context.pytable(result)} - - class Config: - extra = Extra.forbid - - -######################### `calibrate` Operation ############ -class CalibrateExtra(BaseModel): - # start_state: Optional[dict[str,float]] - # pseudocount: float = Field( - # 1.0, description="Optional field for CIEMSS calibration", example=1.0 - # ) - start_time: float = Field( - -1e-10, description="Optional field for CIEMSS calibration", example=-1e-10 - ) - num_iterations: int = Field( - 1000, description="Optional field for CIEMSS calibration", example=1000 - ) - lr: float = Field( - 0.03, description="Optional field for CIEMSS calibration", example=0.03 - ) - verbose: bool = Field( - False, description="Optional field for CIEMSS calibration", example=False - ) - num_particles: int = Field( - 1, description="Optional field for CIEMSS calibration", example=1 - ) - # autoguide: pyro.infer.autoguide.AutoLowRankMultivariateNormal - solver_method: str = Field( - "dopri5", description="Optional field for CIEMSS calibration", example="dopri5" - ) - - -class Calibrate(OperationRequest): - pyciemss_lib_function: ClassVar[str] = "calibrate" - model_config_id: str = Field(..., example="c1cd941a-047d-11ee-be56") - dataset: Dataset = None - timespan: Optional[Timespan] = None - extra: CalibrateExtra = Field( - None, - description="optional extra system specific arguments for advanced use cases", - ) - - def gen_pyciemss_args(self, job_id): - amr_path = fetch_model(self.model_config_id, job_id) - - dataset_path = fetch_dataset(self.dataset.dict(), job_id) - - # TODO: Test RabbitMQ - try: - hook = gen_rabbitmq_hook(job_id) - except (socket.gaierror, AMQPConnectionError): - logging.warning( - "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id - ) - - def hook(progress, _loss): - progress = progress / 10 # TODO: Fix magnitude of progress upstream - if progress == int(progress): - logging.info(f"Calibration is {progress}% complete") - return None - - return { - "model_path_or_json": amr_path, - "start_time": self.timespan.start, - # TODO: Is this intentionally missing from `calibrate`? - # "end_time": self.timespan.end, - "data_path": dataset_path, - "progress_hook": hook, - # "visual_options": True, - **self.extra.dict(), - } - - class Config: - extra = Extra.forbid - - -######################### `ensemble-simulate` Operation ############ -class EnsembleSimulateExtra(BaseModel): - num_samples: int = Field( - 100, description="number of samples for a CIEMSS simulation", example=100 - ) - - -class EnsembleSimulate(OperationRequest): - pyciemss_lib_function: ClassVar[str] = "ensemble_sample" - model_configs: List[ModelConfig] = Field( - [], - example=[], - ) - timespan: Timespan - - step_size: float = 1.0 - - extra: EnsembleSimulateExtra = Field( - None, - description="optional extra system specific arguments for advanced use cases", - ) - - def gen_pyciemss_args(self, job_id): - weights = torch.tensor([config.weight for config in self.model_configs]) - solution_mappings = [ - convert_to_solution_mapping(config) for config in self.model_configs - ] - amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs] - - return { - "model_paths_or_jsons": amr_paths, - "solution_mappings": solution_mappings, - "start_time": self.timespan.start, - "end_time": self.timespan.end, - "logging_step_size": self.step_size, - "dirichlet_alpha": weights, - # "visual_options": True, - **self.extra.dict(), - } - - class Config: - extra = Extra.forbid - - -######################### `ensemble-calibrate` Operation ############ -# class EnsembleCalibrateExtra(BaseModel): -# num_samples: int = Field( -# 100, description="number of samples for a CIEMSS simulation", example=100 -# ) - -# total_population: int = Field(1000, description="total population", example=1000) - -# num_iterations: int = Field(350, description="number of iterations", example=1000) - -# time_unit: int = Field( -# "days", description="units in numbers of days", example="days" -# ) - - -# class EnsembleCalibrate(OperationRequest): -# pyciemss_lib_function: ClassVar[ -# str -# ] = "load_and_calibrate_and_sample_ensemble_model" -# user_id: str = Field("not_provided", example="not_provided") -# model_configs: List[ModelConfig] = Field( -# [], -# example=[], -# ) -# timespan: Timespan = Timespan(start=0, end=90) -# dataset: Dataset -# extra: EnsembleCalibrateExtra = Field( -# None, -# description="optional extra system specific arguments for advanced use cases", -# ) - -# def gen_pyciemss_args(self, job_id): -# weights = [config.weight for config in self.model_configs] -# solution_mappings = [config.solution_mappings for config -# in self.model_configs] -# amr_paths = [ -# fetch_model(config.id, job_id) -# for config in self.model_configs -# ] - -# dataset_path = fetch_dataset(self.dataset.dict(), job_id) - -# # Generate timepoints -# time_count = self.timespan.end - self.timespan.start -# timepoints = [step for step in range(1, time_count + 1)] - -# return { -# "petri_model_or_paths": amr_paths, -# "weights": weights, -# "solution_mappings": solution_mappings, -# "timepoints": timepoints, -# "data_path": dataset_path, -# "visual_options": True, -# **self.extra.dict(), -# } - -# class Config: -# extra = Extra.forbid - - -######################### API Response ############ -class JobResponse(BaseModel): - simulation_id: Optional[str] = Field( - None, - description="Simulation created successfully", - example="fc5d80e4-0483-11ee-be56", - ) - - -class StatusSimulationIdGetResponse(BaseModel): - status: Optional[Status] = None diff --git a/service/models/__init__.py b/service/models/__init__.py new file mode 100644 index 0000000..60c65f5 --- /dev/null +++ b/service/models/__init__.py @@ -0,0 +1,4 @@ +import models.base +import models.converters +from models.operations import * +from models.response import * diff --git a/service/models/base.py b/service/models/base.py new file mode 100644 index 0000000..988e14c --- /dev/null +++ b/service/models/base.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import ClassVar, Dict +from pydantic import BaseModel, Field + + +class Timespan(BaseModel): + start: float = Field(..., example=0) + end: float = Field(..., example=90) + + +class ModelConfig(BaseModel): + id: str = Field(..., example="cd339570-047d-11ee-be55") + solution_mappings: dict[str, str] = Field( + ..., + example={"Infected": "Cases", "Hospitalizations": "hospitalized_population"}, + ) + weight: float = Field(..., example="cd339570-047d-11ee-be55") + + +class Dataset(BaseModel): + id: str = Field(None, example="cd339570-047d-11ee-be55") + filename: str = Field(None, example="dataset.csv") + mappings: Dict[str, str] = Field( + default_factory=dict, + description=( + "Mappings from the dataset column names to " + "the model names they should be replaced with." + ), + example={"postive_tests": "infected"}, + ) + + +class InterventionObject(BaseModel): + timestep: float + name: str + value: float + + +class InterventionSelection(BaseModel): + timestep: float + name: str + + +class QuantityOfInterest(BaseModel): + function: str + state: str + arg: int # TODO: Make this a list of args? + + +class OperationRequest(BaseModel): + pyciemss_lib_function: ClassVar[str] = "" + engine: str = Field("ciemss", example="ciemss") + user_id: str = Field("not_provided", example="not_provided") + + def gen_pyciemss_args(self, job_id): + raise NotImplementedError("PyCIEMSS cannot handle this operation") + + def run_sciml_operation(self, job_id, julia_context): + raise NotImplementedError("SciML cannot handle this operation") + + # @field_validator("engine") + # def must_be_ciemss(cls, engine_choice): + # if engine_choice != "ciemss": + # raise ValueError("The chosen engine is NOT 'ciemss'") + # return engine_choice diff --git a/service/utils/convert.py b/service/models/converters.py similarity index 100% rename from service/utils/convert.py rename to service/models/converters.py diff --git a/service/models/operations/__init__.py b/service/models/operations/__init__.py new file mode 100644 index 0000000..c5c4867 --- /dev/null +++ b/service/models/operations/__init__.py @@ -0,0 +1,3 @@ +from models.operations.simulate import Simulate +from models.operations.calibrate import Calibrate +from models.operations.ensemble_simulate import EnsembleSimulate diff --git a/service/models/operations/calibrate.py b/service/models/operations/calibrate.py new file mode 100644 index 0000000..7ce4412 --- /dev/null +++ b/service/models/operations/calibrate.py @@ -0,0 +1,83 @@ +from __future__ import annotations +import socket +import logging + +from typing import ClassVar, Optional +from pydantic import BaseModel, Field, Extra + +from pika.exceptions import AMQPConnectionError + + +from models.base import Dataset, OperationRequest, Timespan +from utils.rabbitmq import gen_rabbitmq_hook +from utils.tds import fetch_dataset, fetch_model + + +class CalibrateExtra(BaseModel): + # start_state: Optional[dict[str,float]] + # pseudocount: float = Field( + # 1.0, description="Optional field for CIEMSS calibration", example=1.0 + # ) + start_time: float = Field( + -1e-10, description="Optional field for CIEMSS calibration", example=-1e-10 + ) + num_iterations: int = Field( + 1000, description="Optional field for CIEMSS calibration", example=1000 + ) + lr: float = Field( + 0.03, description="Optional field for CIEMSS calibration", example=0.03 + ) + verbose: bool = Field( + False, description="Optional field for CIEMSS calibration", example=False + ) + num_particles: int = Field( + 1, description="Optional field for CIEMSS calibration", example=1 + ) + # autoguide: pyro.infer.autoguide.AutoLowRankMultivariateNormal + solver_method: str = Field( + "dopri5", description="Optional field for CIEMSS calibration", example="dopri5" + ) + + +class Calibrate(OperationRequest): + pyciemss_lib_function: ClassVar[str] = "calibrate" + model_config_id: str = Field(..., example="c1cd941a-047d-11ee-be56") + dataset: Dataset = None + timespan: Optional[Timespan] = None + extra: CalibrateExtra = Field( + None, + description="optional extra system specific arguments for advanced use cases", + ) + + def gen_pyciemss_args(self, job_id): + amr_path = fetch_model(self.model_config_id, job_id) + + dataset_path = fetch_dataset(self.dataset.dict(), job_id) + + # TODO: Test RabbitMQ + try: + hook = gen_rabbitmq_hook(job_id) + except (socket.gaierror, AMQPConnectionError): + logging.warning( + "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id + ) + + def hook(progress, _loss): + progress = progress / 10 # TODO: Fix magnitude of progress upstream + if progress == int(progress): + logging.info(f"Calibration is {progress}% complete") + return None + + return { + "model_path_or_json": amr_path, + "start_time": self.timespan.start, + # TODO: Is this intentionally missing from `calibrate`? + # "end_time": self.timespan.end, + "data_path": dataset_path, + "progress_hook": hook, + # "visual_options": True, + **self.extra.dict(), + } + + class Config: + extra = Extra.forbid diff --git a/service/models/operations/ensemble_simulate.py b/service/models/operations/ensemble_simulate.py new file mode 100644 index 0000000..bb68157 --- /dev/null +++ b/service/models/operations/ensemble_simulate.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import ClassVar, List + +from pydantic import BaseModel, Field, Extra +import torch # TODO: Do not use Torch in PyCIEMSS Library interface + +from models.base import OperationRequest, Timespan, ModelConfig +from models.converters import convert_to_solution_mapping +from utils.tds import fetch_model + + +class EnsembleSimulateExtra(BaseModel): + num_samples: int = Field( + 100, description="number of samples for a CIEMSS simulation", example=100 + ) + + +class EnsembleSimulate(OperationRequest): + pyciemss_lib_function: ClassVar[str] = "ensemble_sample" + model_configs: List[ModelConfig] = Field( + [], + example=[], + ) + timespan: Timespan + + step_size: float = 1.0 + + extra: EnsembleSimulateExtra = Field( + None, + description="optional extra system specific arguments for advanced use cases", + ) + + def gen_pyciemss_args(self, job_id): + weights = torch.tensor([config.weight for config in self.model_configs]) + solution_mappings = [ + convert_to_solution_mapping(config) for config in self.model_configs + ] + amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs] + + return { + "model_paths_or_jsons": amr_paths, + "solution_mappings": solution_mappings, + "start_time": self.timespan.start, + "end_time": self.timespan.end, + "logging_step_size": self.step_size, + "dirichlet_alpha": weights, + # "visual_options": True, + **self.extra.dict(), + } + + class Config: + extra = Extra.forbid diff --git a/service/models/operations/simulate.py b/service/models/operations/simulate.py new file mode 100644 index 0000000..793a02c --- /dev/null +++ b/service/models/operations/simulate.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import ClassVar, List, Optional +from pydantic import BaseModel, Field, Extra + + +from models.base import OperationRequest, Timespan, InterventionObject +from models.converters import convert_to_static_interventions +from utils.tds import fetch_model, fetch_inferred_parameters + + +class SimulateExtra(BaseModel): + num_samples: int = Field( + 100, description="number of samples for a CIEMSS simulation", example=100 + ) + inferred_parameters: Optional[str] = Field( + None, + description="id from a previous calibration", + example=None, + ) + + +class Simulate(OperationRequest): + pyciemss_lib_function: ClassVar[str] = "sample" + model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") + timespan: Timespan = Timespan(start=0, end=90) + interventions: List[InterventionObject] = Field( + default_factory=list, example=[{"timestep": 1, "name": "beta", "value": 0.4}] + ) + step_size: float = 1.0 + extra: SimulateExtra = Field( + None, + description="optional extra system specific arguments for advanced use cases", + ) + + def gen_pyciemss_args(self, job_id): + # Get model from TDS + amr_path = fetch_model(self.model_config_id, job_id) + + interventions = convert_to_static_interventions(self.interventions) + + extra_options = self.extra.dict() + inferred_parameters = fetch_inferred_parameters( + extra_options.pop("inferred_parameters"), job_id + ) + + return { + "model_path_or_json": amr_path, + "logging_step_size": self.step_size, + "start_time": self.timespan.start, + "end_time": self.timespan.end, + "static_parameter_interventions": interventions, + "inferred_parameters": inferred_parameters, + **extra_options, + } + + def run_sciml_operation(self, job_id, julia_context): + amr_path = fetch_model(self.model_config_id, job_id) + with open(amr_path, "r") as file: + amr = file.read() + result = julia_context.simulate(amr, self.timespan.start, self.timespan.end) + return {"data": julia_context.pytable(result)} + + class Config: + extra = Extra.forbid diff --git a/service/models/response.py b/service/models/response.py new file mode 100644 index 0000000..0b7d69c --- /dev/null +++ b/service/models/response.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class Status(Enum): + cancelled = "cancelled" + complete = "complete" + error = "error" + queued = "queued" + running = "running" + failed = "failed" + started = "started" + finished = "finished" + + @staticmethod + def from_rq(rq_status): + rq_status_to_tds_status = { + "canceled": "cancelled", + "complete": "complete", + "error": "error", + "queued": "queued", + "running": "running", + "failed": "failed", + "started": "running", + "finished": "complete", + } + return Status(rq_status_to_tds_status[rq_status]) + + +class JobResponse(BaseModel): + simulation_id: Optional[str] = Field( + None, + description="Simulation created successfully", + example="fc5d80e4-0483-11ee-be56", + ) + + +class StatusSimulationIdGetResponse(BaseModel): + status: Optional[Status] = None