Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Add middlewares and one specific to expose /metrics endpoint for pr…
Browse files Browse the repository at this point in the history
…ometheus (#629)

did this in spare time instead of
#591

using https://github.com/trallnag/prometheus-fastapi-instrumentator

these middlewares can be enabled via CLI if referenced by their `type`
ClassVar:
```
$ mlem --tb serve fastapi --model ../emoji/lyrics2emoji --middlewares.0 prometheus_fastapi
```

---------

Co-authored-by: mike0sv <[email protected]>
  • Loading branch information
aguschin and mike0sv authored Mar 25, 2023
1 parent 19b1066 commit 4c5be67
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 16 deletions.
33 changes: 30 additions & 3 deletions mlem/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
FastAPIServer implementation
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from types import ModuleType
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type
Expand All @@ -24,6 +25,7 @@
InterfaceArgument,
InterfaceMethod,
)
from mlem.runtime.middleware import Middleware, Middlewares
from mlem.runtime.server import Server
from mlem.ui import EMOJI_NAILS, echo
from mlem.utils.module import get_object_requirements
Expand All @@ -48,6 +50,12 @@ def _create_schema_route(app: FastAPI, interface: Interface):
app.add_api_route("/interface.json", lambda: schema, tags=["schema"])


class FastAPIMiddleware(Middleware, ABC):
@abstractmethod
def on_app_init(self, app: FastAPI):
raise NotImplementedError


class FastAPIServer(Server, LibRequirementsMixin):
"""Serves model with http"""

Expand All @@ -70,6 +78,7 @@ def _create_handler_executor(
arg_serializers: Dict[str, DataTypeSerializer],
executor: Callable,
response_serializer: DataTypeSerializer,
middlewares: Middlewares,
):
deserialized_model = create_model(
"Model", **{a: (Any, ...) for a in args}
Expand Down Expand Up @@ -99,7 +108,9 @@ def serializer_validator(_, values):

def bin_handler(model: schema_model): # type: ignore[valid-type]
values = {a: getattr(model, a) for a in args}
values = middlewares.on_request(values)
result = executor(**values)
result = middlewares.on_response(values, result)
with response_serializer.dump(result) as buffer:
return StreamingResponse(
buffer, media_type="application/octet-stream"
Expand All @@ -113,7 +124,9 @@ def bin_handler(model: schema_model): # type: ignore[valid-type]

def handler(model: schema_model): # type: ignore[valid-type]
values = {a: getattr(model, a) for a in args}
values = middlewares.on_request(values)
result = executor(**values)
result = middlewares.on_response(values, result)
response = response_serializer.serialize(result)
return parse_obj_as(response_model, response)

Expand All @@ -127,12 +140,15 @@ def _create_handler_executor_binary(
arg_name: str,
executor: Callable,
response_serializer: DataTypeSerializer,
middlewares: Middlewares,
):
if response_serializer.serializer.is_binary:

def bin_handler(file: UploadFile):
arg = serializer.deserialize(_SpooledFileIOWrapper(file.file))
arg = middlewares.on_request(arg)
result = executor(**{arg_name: arg})
result = middlewares.on_response(arg, result)
with response_serializer.dump(result) as buffer:
return StreamingResponse(
buffer, media_type="application/octet-stream"
Expand All @@ -146,15 +162,20 @@ def bin_handler(file: UploadFile):

def handler(file: UploadFile):
arg = serializer.deserialize(file.file)
arg = middlewares.on_request(arg)
result = executor(**{arg_name: arg})

result = middlewares.on_response(arg, result)
response = response_serializer.serialize(result)
return parse_obj_as(response_model, response)

return handler, response_model, None

def _create_handler(
self, method_name: str, signature: InterfaceMethod, executor: Callable
self,
method_name: str,
signature: InterfaceMethod,
executor: Callable,
middlewares: Middlewares,
) -> Tuple[Optional[Callable], Optional[Type], Optional[Response]]:
serializers, response_serializer = self._get_serializers(signature)
echo(EMOJI_NAILS + f"Adding route for /{method_name}")
Expand All @@ -170,13 +191,15 @@ def _create_handler(
arg_name,
executor,
response_serializer,
middlewares,
)
return self._create_handler_executor(
method_name,
{a.name: a for a in signature.args},
serializers,
executor,
response_serializer,
middlewares,
)

def app_init(self, interface: Interface):
Expand All @@ -185,11 +208,15 @@ def app_init(self, interface: Interface):
app.add_api_route(
"/", lambda: RedirectResponse("/docs"), include_in_schema=False
)
for mid in self.middlewares.__root__:
mid.on_init()
if isinstance(mid, FastAPIMiddleware):
mid.on_app_init(app)

for method, signature in interface.iter_methods():
executor = interface.get_method_executor(method)
handler, response_model, response_class = self._create_handler(
method, signature, executor
method, signature, executor, self.middlewares
)

app.add_api_route(
Expand Down
65 changes: 65 additions & 0 deletions mlem/contrib/prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Instrumenting FastAPI app to expose metrics for prometheus
Extension type: middleware
Exposes /metrics endpoint
"""
from typing import ClassVar, List, Optional

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

from mlem.contrib.fastapi import FastAPIMiddleware
from mlem.utils.importing import import_string_with_local
from mlem.utils.module import get_object_requirements


class PrometheusFastAPIMiddleware(FastAPIMiddleware):
"""Middleware for FastAPI server that exposes /metrics endpoint to be scraped by Prometheus"""

type: ClassVar = "prometheus_fastapi"

metrics: List[str] = []
"""Instrumentator instance to use. If not provided, a new one will be created"""
instrumentator_cache: Optional[Instrumentator] = None

class Config:
arbitrary_types_allowed = True
exclude = {"instrumentator_cache"}

@property
def instrumentator(self):
if self.instrumentator_cache is None:
self.instrumentator_cache = self.get_instrumentator()
return self.instrumentator_cache

def on_app_init(self, app: FastAPI):
@app.on_event("startup")
async def _startup():
self.instrumentator.expose(app)

def on_init(self):
pass

def on_request(self, request):
return request

def on_response(self, request, response):
return response

def get_instrumentator(self):
instrumentator = Instrumentator()
for metric in self._iter_metric_objects():
# todo: check object type
instrumentator.add(metric)
return instrumentator

def _iter_metric_objects(self):
for metric in self.metrics:
# todo: meaningful error on import error
yield import_string_with_local(metric)

def get_requirements(self):
reqs = super().get_requirements()
for metric in self._iter_metric_objects():
reqs += get_object_requirements(metric)
return reqs
1 change: 1 addition & 0 deletions mlem/contrib/sagemaker/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def app_init(self, interface: Interface):
"invocations",
interface.get_method_signature(self.method),
interface.get_method_executor(self.method),
self.middlewares,
)
app.add_api_route(
"/invocations",
Expand Down
11 changes: 2 additions & 9 deletions mlem/core/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import shlex
import sys
from collections import defaultdict
from inspect import isabstract
from typing import (
Expand All @@ -22,7 +21,7 @@

from mlem.core.errors import ExtensionRequirementError, UnknownImplementation
from mlem.polydantic import PolyModel
from mlem.utils.importing import import_string
from mlem.utils.importing import import_string_with_local
from mlem.utils.path import make_posix


Expand Down Expand Up @@ -64,18 +63,12 @@ def load_impl_ext(

if type_name is not None and "." in type_name:
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")

obj = import_string(type_name)
obj = import_string_with_local(type_name)
if not issubclass(obj, MlemABC):
raise ValueError(f"{obj} is not subclass of MlemABC")
return obj
except ImportError:
pass
finally:
sys.path.remove(".")

eps = load_entrypoints()
for ep in eps.values():
Expand Down
5 changes: 5 additions & 0 deletions mlem/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class ExtensionLoader:
Extension("mlem.contrib.xgboost", ["xgboost"], False),
Extension("mlem.contrib.docker", ["docker"], False),
Extension("mlem.contrib.fastapi", ["fastapi", "uvicorn"], False),
Extension(
"mlem.contrib.prometheus",
["prometheus-fastapi-instrumentator"],
False,
),
Extension("mlem.contrib.callable", [], True),
Extension("mlem.contrib.rabbitmq", ["pika"], False, extra="rmq"),
Extension("mlem.contrib.github", [], True),
Expand Down
51 changes: 51 additions & 0 deletions mlem/runtime/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import abstractmethod
from typing import ClassVar, List

from pydantic import BaseModel

from mlem.core.base import MlemABC
from mlem.core.requirements import Requirements, WithRequirements


class Middleware(MlemABC, WithRequirements):
abs_name: ClassVar = "middleware"

class Config:
type_root = True

@abstractmethod
def on_init(self):
raise NotImplementedError

@abstractmethod
def on_request(self, request):
raise NotImplementedError

@abstractmethod
def on_response(self, request, response):
raise NotImplementedError


class Middlewares(BaseModel):
__root__: List[Middleware] = []
"""Middlewares to add to server"""

def on_init(self):
for middleware in self.__root__:
middleware.on_init()

def on_request(self, request):
for middleware in self.__root__:
request = middleware.on_request(request)
return request

def on_response(self, request, response):
for middleware in reversed(self.__root__):
response = middleware.on_response(request, response)
return response

def get_requirements(self) -> Requirements:
reqs = Requirements.new()
for m in self.__root__:
reqs += m.get_requirements()
return reqs
16 changes: 14 additions & 2 deletions mlem/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InterfaceDescriptor,
InterfaceMethod,
)
from mlem.runtime.middleware import Middlewares
from mlem.utils.module import get_object_requirements

MethodMapping = Dict[str, str]
Expand Down Expand Up @@ -120,6 +121,9 @@ class Config:
additional_source_files: ClassVar[Optional[List[str]]] = None
port_field: ClassVar[Optional[str]] = None

middlewares: Middlewares = Middlewares()
"""Middlewares to add to server"""

# @validator("interface")
# @classmethod
# def validate_interface(cls, value):
Expand Down Expand Up @@ -155,8 +159,16 @@ def _get_serializers(
return arg_serializers, returns

def get_requirements(self) -> Requirements:
return super().get_requirements() + get_object_requirements(
[self.request_serializer, self.response_serializer, self.methods]
return (
super().get_requirements()
+ get_object_requirements(
[
self.request_serializer,
self.response_serializer,
self.methods,
]
)
+ self.middlewares.get_requirements()
)

def get_ports(self) -> List[int]:
Expand Down
10 changes: 10 additions & 0 deletions mlem/utils/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def module_imported(module_name):
return sys.modules.get(module_name) is not None


def import_string_with_local(path):
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")
return import_string(path)
finally:
sys.path.remove(".")


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"xgboost": ["xgboost"],
"lightgbm": ["lightgbm"],
"fastapi": ["uvicorn", "fastapi"],
"prometheus": ["prometheus-fastapi-instrumentator"],
"streamlit": ["uvicorn", "fastapi", "streamlit", "streamlit_pydantic"],
"sagemaker": ["docker", "boto3", "sagemaker"],
"torch": ["torch"],
Expand Down Expand Up @@ -214,6 +215,7 @@
"serializer.pil_numpy = mlem.contrib.pil:PILImageSerializer",
"builder.pip = mlem.contrib.pip.base:PipBuilder",
"builder.whl = mlem.contrib.pip.base:WhlBuilder",
"middleware.prometheus_fastapi = mlem.contrib.prometheus:PrometheusFastAPIMiddleware",
"client.rmq = mlem.contrib.rabbitmq:RabbitMQClient",
"server.rmq = mlem.contrib.rabbitmq:RabbitMQServer",
"builder.requirements = mlem.contrib.requirements:RequirementsBuilder",
Expand Down
Loading

0 comments on commit 4c5be67

Please sign in to comment.