diff --git a/mlchain/__init__.py b/mlchain/__init__.py index 5d45168..2ccaa58 100644 --- a/mlchain/__init__.py +++ b/mlchain/__init__.py @@ -1,9 +1,20 @@ # Parameters of MLchain -__version__ = "0.1.8rc1" +__version__ = "0.1.9" HOST = "https://www.api.mlchain.ml" WEB_HOST = HOST API_ADDRESS = HOST MODEL_ID = None +import ssl + +try: + _create_unverified_https_context = ssl._create_unverified_context +except AttributeError: + # Legacy Python that doesn't verify HTTPS certificates by default + pass +else: + # Handle target environment that doesn't support HTTPS verification + ssl._create_default_https_context = _create_unverified_https_context + from os import environ environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' @@ -11,4 +22,5 @@ from .context import mlchain_context from .base.exceptions import * -from .config import mlconfig \ No newline at end of file +from .config import mlconfig + diff --git a/mlchain/base/exceptions.py b/mlchain/base/exceptions.py index 7b24b7a..1a8c13e 100644 --- a/mlchain/base/exceptions.py +++ b/mlchain/base/exceptions.py @@ -1,17 +1,17 @@ import traceback -from .log import logger - +from .log import sentry_ignore_logger class MlChainError(Exception): """Base class for all exceptions.""" def __init__(self, msg, code='exception', status_code=500): + super(MlChainError, self).__init__(msg) self.msg = msg + self.message = msg self.code = code self.status_code = status_code - logger.error("[{0}]: {1}".format(code, msg)) - logger.debug(traceback.format_exc()) - + sentry_ignore_logger.error("[{0}]: {1}".format(code, msg)) + sentry_ignore_logger.debug(traceback.format_exc()) class MLChainAssertionError(MlChainError): def __init__(self, msg, code="assertion", status_code=422): @@ -26,3 +31,11 @@ def __init__(self, msg, code="serialization", status_code=422): class MLChainUnauthorized(MlChainError): def __init__(self, msg, code="unauthorized", status_code=401): MlChainError.__init__(self, msg, code, status_code) + +class MLChainConnectionError(MlChainError): + def __init__(self, msg, code="connection_error", status_code=500): + MlChainError.__init__(self, msg, code, status_code) + +class MLChainTimeoutError(MlChainError): + def __init__(self, msg, code="timeout", status_code=500): + MlChainError.__init__(self, msg, code, status_code) \ No newline at end of file diff --git a/mlchain/base/gunicorn_config.py b/mlchain/base/gunicorn_config.py new file mode 100644 index 0000000..9c26f9e --- /dev/null +++ b/mlchain/base/gunicorn_config.py @@ -0,0 +1,4 @@ +from mlchain.config import init_sentry + +def post_worker_init(worker): + init_sentry() \ No newline at end of file diff --git a/mlchain/base/log.py b/mlchain/base/log.py index 2b16b07..8320e51 100644 --- a/mlchain/base/log.py +++ b/mlchain/base/log.py @@ -4,17 +4,28 @@ """ from contextlib import contextmanager import re +import traceback from traceback import StackSummary, extract_tb import os import sys import logging -import traceback + +# Sentry integration +from sentry_sdk.integrations.logging import LoggingIntegration +from sentry_sdk.integrations.logging import ignore_logger + +sentry_logging = LoggingIntegration( + level=logging.INFO, # Capture info and above as breadcrumbs + event_level=logging.ERROR # Send errors as events +) +ignore_logger("mlchain-server") +sentry_ignore_logger = logging.getLogger("mlchain-server") +# End sentry integration def get_color(n): return 
'\x1b[3{0}m'.format(n) - class MultiLine(logging.Formatter): def __init__(self, fmt=None, datefmt=None, style='%', newline=None): logging.Formatter.__init__(self, fmt, datefmt, style) @@ -60,7 +71,7 @@ def except_handler(): sys.excepthook = sys.__excepthook__ -def format_exc(name='mlchain', tb=None, exception=None): +def format_exc(name='mlchain', tb=None, exception=None, return_str=True): if exception is None: formatted_lines = traceback.format_exc().splitlines() else: @@ -78,4 +89,7 @@ def format_exc(name='mlchain', tb=None, exception=None): output = [] for x in formatted_lines: output.append(x) - return "\n".join(output) + "\n" + + if return_str: + return "\n".join(output) + return output diff --git a/mlchain/base/serve_model.py b/mlchain/base/serve_model.py index a8e9b85..4c02695 100644 --- a/mlchain/base/serve_model.py +++ b/mlchain/base/serve_model.py @@ -311,7 +311,7 @@ def call_function(self, function_name_, id_=None, *args, **kwargs): function_name, uid = function_name_, id_ if function_name is None: raise AssertionError("You need to specify the function name (API name)") - mlchain_context['context_id'] = uid + if isinstance(function_name, str): if len(function_name) == 0: if hasattr(self.model, '__call__') and callable(getattr(self.model, '__call__')): @@ -339,7 +339,7 @@ async def call_async_function(self, function_name_, id_=None, *args, **kwargs): function_name, uid = function_name_, id_ if function_name is None: raise MLChainAssertionError("You need to specify the function name (API name)") - mlchain_context['context_id'] = uid + if isinstance(function_name, str): if len(function_name) == 0: if hasattr(self.model, '__call__') and callable(getattr(self.model, '__call__')): diff --git a/mlchain/base/wrapper.py b/mlchain/base/wrapper.py index 65382de..9294545 100644 --- a/mlchain/base/wrapper.py +++ b/mlchain/base/wrapper.py @@ -23,6 +23,9 @@ def load_config(self): for key, value in config.items(): self.cfg.set(key.lower(), value) + from mlchain.base.gunicorn_config import post_worker_init + self.cfg.set("post_worker_init", post_worker_init) + def load(self): return self.application diff --git a/mlchain/cli/config.yaml b/mlchain/cli/config.yaml deleted file mode 100644 index 32cb77c..0000000 --- a/mlchain/cli/config.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: mlchain-server # name of service -version: '0.0.1' -entry_file: server.py # python file contains object ServeModel -host: localhost # host service -port: 8001 # port service -server: flask # option flask or quart or grpc -wrapper: None # option None or gunicorn or hypercorn -cors: true -static_folder: # static folder for TemplateResponse -static_url_path: # static url path for TemplateResponse -template_folder: # template folder for TemplateResponse -gunicorn: # config apm-server if uses gunicorn wrapper - timeout: 60 - keepalive: 60 - max_requests: 0 - threads: 1 - worker_class: 'gthread' - umask: '0' -hypercorn: # config apm-server if uses hypercorn wrapper - keep_alive_timeout: 60 - worker_class: 'asyncio' - umask: 0 \ No newline at end of file diff --git a/mlchain/cli/init.py b/mlchain/cli/init.py index 593a692..2051b3b 100644 --- a/mlchain/cli/init.py +++ b/mlchain/cli/init.py @@ -1,18 +1,20 @@ -import os import click -from mlchain import logger +import os root_path = os.path.dirname(__file__) - @click.command("init", short_help="Init base config to run server.") -@click.argument('file', nargs=1, required=False, default='mlconfig.yaml') -def init_command(file): - if file is None: - file = 'mlconfig.yaml' # pragma: no 
cover - if os.path.exists(file): - logger.warning("File {} exists. Please change name file".format(file)) - else: - with open(file, 'wb') as fp: - with open(os.path.join(root_path, 'config.yaml'), 'rb') as fr: - fp.write(fr.read()) +def init_command(): + def create_file(file): + with open(file, 'wb') as f: + f.write(open(os.path.join(root_path, file), 'rb').read()) + + ALL_INIT_FILES = ['mlconfig.yaml', 'mlchain_server.py'] + for file in ALL_INIT_FILES: + if os.path.exists(file): + if click.confirm('File {0} already exists. Do you want to overwrite it?'.format(file)): + create_file(file) + else: + create_file(file) + + click.secho('MLChain initialization is done!', blink=True, bold=True) \ No newline at end of file diff --git a/mlchain/cli/mlchain_server.py b/mlchain/cli/mlchain_server.py new file mode 100644 index 0000000..4c49d69 --- /dev/null +++ b/mlchain/cli/mlchain_server.py @@ -0,0 +1,20 @@ +""" +THE BASE MLCHAIN SERVER +""" +# Import mlchain +from mlchain.base import ServeModel +from mlchain import mlconfig + + +# IMPORT YOUR CLASS HERE - THIS IS THE ONLY PART YOU NEED TO CHANGE +from main import Test # Import your class here + +model = Test() # Initialize your class first +# END YOUR WORK HERE + + +# Wrap your class with mlchain's ServeModel +serve_model = ServeModel(model) + +# THEN GO TO CONSOLE: +# mlchain run -c mlconfig.yaml \ No newline at end of file diff --git a/mlchain/cli/mlconfig.yaml b/mlchain/cli/mlconfig.yaml new file mode 100644 index 0000000..77c5db0 --- /dev/null +++ b/mlchain/cli/mlconfig.yaml @@ -0,0 +1,54 @@ +# Service Config +name: mlchain-server # Name of service +version: '0.0.1' # Version of service +entry_file: mlchain_server.py # Python file that contains the ServeModel object + +# Host and Port Config +host: 0.0.0.0 # Host of service +port: 8001 # Port of service + +# Server config +server: flask # Options: flask, quart, or grpc +wrapper: None # Options: None, gunicorn, or hypercorn +cors: true # Auto enable CORS +static_folder: # static folder for TemplateResponse +static_url_path: # static url path for TemplateResponse +template_folder: # template folder for TemplateResponse + +# Gunicorn config - Use gunicorn for the general case +gunicorn: + timeout: 200 # Requests time out after 200 seconds by default; the worker restarts once a request finishes + keepalive: 3 # Keep idle client connections alive for 3 seconds by default + max_requests: 0 # Restart workers after this many requests to contain Python memory growth (0 disables) + workers: 1 # Number of worker processes + threads: 1 # Number of concurrent threads per worker + worker_class: 'gthread' # gthread fits most cases; 'uvicorn.workers.UvicornWorker' can sometimes perform better + +# Hypercorn config - Use hypercorn for an async server with Quart +hypercorn: + timeout: 200 # Requests time out after 200 seconds by default; the worker restarts once a request finishes + keepalive: 3 # Keep idle client connections alive for 3 seconds by default + threads: 50 # Number of concurrent threads per worker. Default: 50. Some models cannot be called concurrently; apply the @non_thread() decorator to such functions (see the sketch below). + worker_class: 'uvloop' # uvloop fits most cases. 
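+  # A short sketch of guarding a function that is not thread-safe with the
+  # @non_thread() decorator mentioned above (the import path below is an
+  # assumption; adjust it to wherever your mlchain version exports it):
+  #     from mlchain.decorators import non_thread
+  #     class Model:
+  #         @non_thread()
+  #         def predict(self, image): ...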
+ +bind: + - 'unix:/tmp/gunicorn.sock' # Using sock to make gunicorn faster + +# Sentry logging, Sentry will be run when the worker is already initialized +sentry: + dsn: None # URI Sentry of the project or export SENTRY_DSN + traces_sample_rate: 0.1 # Default log 0.1 + sample_rate: 1.0 # Default 1.0 + drop_modules: True # Drop python requirements to lower the size of log + +# Mlconfig - Use these mode and config or env to adaptive your code +# You can import mlconfig and use as variable. Ex: mlconfig.debug +mode: + default: default # The default mode + env: + default: # All variable in default mode will be existed in other mode + test: "Hello" + dev: # Development mode + debug: True + prod: # Production mode + debug: False \ No newline at end of file diff --git a/mlchain/cli/run.py b/mlchain/cli/run.py index a0c43ca..581bb24 100644 --- a/mlchain/cli/run.py +++ b/mlchain/cli/run.py @@ -2,12 +2,13 @@ import click import importlib import sys +import copy import GPUtil from mlchain import logger from mlchain.server import MLServer from mlchain.base import ServeModel from mlchain.server.authentication import Authentication - +import traceback def select_gpu(): try: @@ -92,25 +93,33 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config, name, mode, api_format, ngrok, kws): kws = list(kws) if isinstance(entry_file, str) and not os.path.exists(entry_file): - kws = [entry_file] + kws + kws = [f'--entry_file={entry_file}'] + kws entry_file = None from mlchain import config as mlconfig default_config = False + if config is None: default_config = True config = 'mlconfig.yaml' - if os.path.isfile(config): - config = mlconfig.load_file(config) + config_path = copy.deepcopy(config) + if os.path.isfile(config_path) and os.path.exists(config_path): + config = mlconfig.load_file(config_path) if config is None: - raise AssertionError("Not support file config {0}".format(config)) + raise SystemExit("Config file {0} are not supported".format(config_path)) else: if not default_config: - raise FileNotFoundError("Not found file {0}".format(config)) - config = {} + raise SystemExit("Can't find config file {0}".format(config_path)) + else: + raise SystemExit("Can't find mlchain config file. Please double check your current working directory. Or use `mlchain init` to initialize a new ones here.") if 'mode' in config and 'env' in config['mode']: if mode in config['mode']['env']: config['mode']['default'] = mode + elif mode is not None: + available_mode = list(config['mode']['env'].keys()) + available_mode = [each for each in available_mode if each != 'default'] + raise SystemExit( + f"No {mode} mode are available. 
Found these modes in the config file: {available_mode}") mlconfig.load_config(config) for kw in kws: if kw.startswith('--'): @@ -124,6 +133,10 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config, raise AssertionError("Unexpected param {0}".format(kw)) model_id = mlconfig.get_value(None, config, 'model_id', None) entry_file = mlconfig.get_value(entry_file, config, 'entry_file', 'server.py') + if entry_file.strip() == '': + raise SystemExit("Entry file cannot be empty") + if not os.path.exists(entry_file): + raise SystemExit(f"Entry file {entry_file} not found in the current working directory.") host = mlconfig.get_value(host, config, 'host', 'localhost') port = mlconfig.get_value(port, config, 'port', 5000) server = mlconfig.get_value(server, config, 'server', 'flask') @@ -134,7 +147,9 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config, if wrapper == 'gunicorn' and os.name == 'nt': logger.warning('Gunicorn warper are not supported on Windows. Switching to None instead.') wrapper = None - workers = mlconfig.get_value(workers, config['gunicorn'], 'workers', None) + workers = None + if 'gunicorn' in config: + workers = mlconfig.get_value(workers, config['gunicorn'], 'workers', None) if workers is None and 'hypercorn' in config.keys(): workers = mlconfig.get_value(workers, config['hypercorn'], 'workers', None) workers = int(workers) if workers is not None else 1 @@ -180,6 +195,10 @@ def run_command(entry_file, host, port, bind, wrapper, server, workers, config, if server == 'grpc': from mlchain.server.grpc_server import GrpcServer app = get_model(entry_file, serve_model=True) + + if app is None: + raise Exception("Cannot init model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!".format(entry_file)) + app = GrpcServer(app, name=name) app.run(host, port) elif wrapper == 'gunicorn': @@ -200,6 +219,9 @@ def load_config(self): for key, value in config.items(): self.cfg.set(key.lower(), value) + from mlchain.base.gunicorn_config import post_worker_init + self.cfg.set("post_worker_init", post_worker_init) + def load(self): original_cuda_variable = os.environ.get('CUDA_VISIBLE_DEVICES') if original_cuda_variable is None: @@ -207,6 +229,10 @@ def load(self): else: logger.info(f"Skipping automatic GPU selection for gunicorn worker since CUDA_VISIBLE_DEVICES environment variable is already set to {original_cuda_variable}") serve_model = get_model(entry_file, serve_model=True) + + if serve_model is None: + raise Exception(f"Cannot init model class from {entry_file}. Please check mlconfig.yaml or {entry_file} or mlchain run -m {{mode}}!") + if isinstance(serve_model, ServeModel): if (not self.autofrontend) and model_id is not None: from mlchain.server.autofrontend import register_autofrontend @@ -259,10 +285,15 @@ def load(self): if 'uvicorn' in gunicorn_config['worker_class']: logger.warning("Can't use flask with uvicorn. change to gthread") gunicorn_config['worker_class'] = 'gthread' + GunicornWrapper(server, bind=bind, **gunicorn_config).run() elif wrapper == 'hypercorn' and server == 'quart': from mlchain.server.quart_server import QuartServer app = get_model(entry_file, serve_model=True) + + if app is None: + raise Exception("Cannot init model class from {0}. 
Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!".format(entry_file)) + + app = QuartServer(app, name=name, version=version, api_format=api_format, authentication=authentication, static_url_path=static_url_path, @@ -272,6 +303,10 @@ def load(self): gunicorn=False, hypercorn=True, **config.get('hypercorn', {}), model_id=model_id) app = get_model(entry_file) + + if app is None: + raise Exception("Cannot init model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!".format(entry_file)) + if isinstance(app, MLServer): if app.__class__.__name__ == 'FlaskServer': app.run(host, port, cors=cors, gunicorn=False) @@ -311,7 +346,12 @@ def load(self): def get_model(module, serve_model=False): import_name = prepare_import(module) - module = importlib.import_module(import_name) + try: + module = importlib.import_module(import_name) + except Exception as ex: + logger.error(traceback.format_exc()) + return None + serve_models = [v for v in module.__dict__.values() if isinstance(v, ServeModel)] if len(serve_models) > 0 and serve_model: serve_model = serve_models[0] @@ -329,5 +369,5 @@ def get_model(module, serve_model=False): serve_model = ServeModel(serve_models[-1]) return serve_model - logger.error("Could not find any instance to serve") + logger.error("Could not find any instance to serve. Please check your mlconfig.yaml or server file!") return None diff --git a/mlchain/client/base.py b/mlchain/client/base.py index c6f7e24..68b570d 100644 --- a/mlchain/client/base.py +++ b/mlchain/client/base.py @@ -6,7 +6,9 @@ MsgpackBloscSerializer, JpgMsgpackSerializer, PngMsgpackSerializer, Serializer) from mlchain.base.log import except_handler, logger from mlchain.server.base import RawResponse - +from sentry_sdk import Hub +from httpx import ConnectError, ReadTimeout, WriteTimeout +from mlchain.base.exceptions import MLChainConnectionError, MLChainTimeoutError class AsyncStorage: def __init__(self, function): @@ -157,21 +188,63 @@ def _post(self, function_name, headers=None, args=None, kwargs=None): raise NotImplementedError def post(self, function_name, headers=None, args=None, kwargs=None): - context = mlchain_context.copy() - headers = self.headers() - context.update(headers) - if 'parent_id' in context: - context.pop('parent_id') - if 'context_id' in context: - context['parent_id'] = context.pop('context_id') - return self._post(function_name, context, args, kwargs) + def _call_post(): + context = {key: value + for (key, value) in mlchain_context.items() if key.startswith('MLCHAIN_CONTEXT_')} + context.update(self.headers()) + + output = None + try: + output = self._post(function_name, context, args, kwargs) + except ConnectError: + raise MLChainConnectionError(msg="Client cannot connect to server: {0}. Function: {1}. POST".format(self.api_address, function_name)) + except TimeoutError: + raise MLChainTimeoutError(msg="Client call timed out to server: {0}. Function: {1}. POST".format(self.api_address, function_name)) + except ReadTimeout: + raise MLChainTimeoutError(msg="Client call timed out to server: {0}. 
Function: {1}. POST".format(self.api_address, function_name)) + except WriteTimeout: + raise MLChainTimeoutError(msg="Client call timeout into Server: {0}. Function: {1}. POST".format(self.api_address, function_name)) + + return output + + transaction = Hub.current.scope.transaction + + if transaction is not None: + with transaction.start_child(op="task", description="{0} {1}".format(self.api_address, function_name)) as span: + return _call_post() + else: + return _call_post() def _get(self, api_name, headers=None, timeout=None): raise NotImplementedError def get(self, api_name, headers=None, timeout=None): - return self._get(api_name, self.headers(), timeout) - + def _call_get(): + context = {key: value + for (key, value) in mlchain_context.items() if key.startswith('MLCHAIN_CONTEXT_')} + context.update(self.headers()) + + output = None + try: + output = self._get(api_name, self.headers(), timeout) + except ConnectError: + raise MLChainConnectionError(msg="Client call can not connect into Server: {0}. Function: {1}. GET".format(self.api_address, api_name)) + except TimeoutError: + raise MLChainTimeoutError(msg="Client call timeout into Server: {0}. Function: {1}. GET".format(self.api_address, api_name)) + except ReadTimeout: + raise MLChainTimeoutError(msg="Client call timeout into Server: {0}. Function: {1}. GET".format(self.api_address, api_name)) + except WriteTimeout: + raise MLChainTimeoutError(msg="Client call timeout into Server: {0}. Function: {1}. GET".format(self.api_address, api_name)) + + return output + + transaction = Hub.current.scope.transaction + + if transaction is not None: + with transaction.start_child(op="task", description="{0} {1}".format(self.api_address, api_name)) as span: + return _call_get() + else: + return _call_get() class BaseFunction: def __init__(self, client, function_name, serializer: Serializer): @@ -191,11 +264,8 @@ def __call__(self, *args, **kwargs): if 'error' in output: with except_handler(): raise Exception( - "MLCHAIN VERSION: {} API VERSION: {} ERROR_CODE: {} INFO_ERROR: {}, ".format( + "MLCHAIN VERSION: {} \n API VERSION: {} \n ERROR_CODE: {} \n INFO_ERROR: {}, ".format( output.get('mlchain_version', None), output.get('api_version', None), output.get('code', None), output['error'])) - logger.debug("MLCHAIN VERSION: {} API VERSION: {}".format( - output.get('mlchain_version', None), - output.get('api_version', None))) return output['output'] diff --git a/mlchain/client/http_client.py b/mlchain/client/http_client.py index 173458a..c14d904 100644 --- a/mlchain/client/http_client.py +++ b/mlchain/client/http_client.py @@ -6,7 +6,6 @@ from mlchain.server.base import RawResponse, JsonResponse from .base import MLClient - class HttpClient(MLClient): def __init__(self, api_key=None, api_address=None, serializer='msgpack', image_encoder=None, name=None, version='lastest', @@ -142,6 +141,9 @@ def _post(self, function_name, headers=None, args=None, kwargs=None): if not self.check_response_ok(output): if output.status_code == 404: raise Exception("This request url is not found") + else: + raise Exception("There 's some error when calling, please check: \n HTTP ERROR: {0} \n DETAIL: ".format(output.status_code, output.content)) + return output.content return output.content diff --git a/mlchain/config.py b/mlchain/config.py index 954257a..9d88e6c 100644 --- a/mlchain/config.py +++ b/mlchain/config.py @@ -1,7 +1,12 @@ import os from os import environ from collections import defaultdict - +from .base.log import logger +import datetime +import sentry_sdk +from 
sentry_sdk.integrations.flask import FlaskIntegration +from mlchain.utils.system_info import get_gpu_statistics class BaseConfig(dict): def __init__(self, env_key='', **kwargs): @@ -13,29 +18,27 @@ def __getattr__(self, item): r = self.get_item(item) if r is not None: return r + r = self.get_item(item.upper()) if r is not None: return r + r = self.get_item(item.lower()) if r is not None: return r - r = self.get_default(item) - return r - def get_item(self, item): if item.upper() in self: return self[item.upper()] + + r = self.get_default(item) + return r - r = environ.get(self.env_key.upper() + item) - if r is not None: - return r - - r = environ.get(self.env_key.lower() + item) + def get_item(self, item): + r = environ.get(item) if r is not None: return r - r = environ.get(item) - return r + return None def from_json(self, path): import json @@ -98,7 +101,8 @@ def load_config(self, path, mode=None): for mode in ['default', default]: if mode in data['mode']['env']: for k, v in data['mode']['env'][mode].items(): - environ[k] = str(v) + if k in environ: + data['mode']['env'][mode][k] = environ[k] self.update(data['mode']['env'][mode]) def get_client_config(self, name): @@ -112,25 +116,82 @@ def get_client_config(self, name): def load_config(data): + mlconfig.update({ + "MLCHAIN_SERVER_NAME": data.get("name", 'mlchain-server'), + "MLCHAIN_SERVER_VERSION": data.get("version", "0.0.1") + }) + + default = 'default' + if 'mode' in data: + if 'default' in data['mode']: + default = data['mode']['default'] + + mlconfig.update({ + "MLCHAIN_DEFAULT_MODE": default + }) + + if "sentry" in data: + mlconfig.update({ + "MLCHAIN_SENTRY_DSN": os.getenv("SENTRY_DSN", data['sentry'].get("dsn", None)), + "MLCHAIN_SENTRY_TRACES_SAMPLE_RATE": os.getenv("SENTRY_TRACES_SAMPLE_RATE", data['sentry'].get("traces_sample_rate", 0.1)), + "MLCHAIN_SENTRY_SAMPLE_RATE": os.getenv("SENTRY_SAMPLE_RATE", data['sentry'].get("sample_rate", 1.0)), + "MLCHAIN_SENTRY_DROP_MODULES": os.getenv("SENTRY_DROP_MODULES", data['sentry'].get("drop_modules", 'True')) not in ['False', 'false', False] + }) + for config in all_configs: env_key = config.env_key.strip('_').lower() if env_key in data: config.update(data[env_key]) + if 'clients' in data: mlconfig.update_client(data['clients']) - if 'mode' in data: - if 'default' in data['mode']: - default = data['mode']['default'] - else: - default = 'default' + if 'mode' in data: if 'env' in data['mode']: for mode in ['default', default]: if mode in data['mode']['env']: for k, v in data['mode']['env'][mode].items(): - environ[k] = str(v) + if k in environ: + data['mode']['env'][mode][k] = environ[k] mlconfig.update(data['mode']['env'][mode]) + + if (mlconfig.MLCHAIN_SENTRY_DSN is not None and mlconfig.MLCHAIN_SENTRY_DSN != 'None') and data.get('wrapper', None) != 'gunicorn': + init_sentry() +def before_send(event, hint): + if mlconfig.MLCHAIN_SENTRY_DROP_MODULES: + event['modules'] = {} + + event.setdefault('extra', {})["gpuinfo"] = get_gpu_statistics() + return event + +def init_sentry(): + if mlconfig.MLCHAIN_SENTRY_DSN is None or mlconfig.MLCHAIN_SENTRY_DSN == 'None': + return None + logger.debug("Initializing Sentry with dsn={0}, traces_sample_rate={1}, sample_rate={2}, drop_modules={3}".format(mlconfig.MLCHAIN_SENTRY_DSN, mlconfig.MLCHAIN_SENTRY_TRACES_SAMPLE_RATE, mlconfig.MLCHAIN_SENTRY_SAMPLE_RATE, mlconfig.MLCHAIN_SENTRY_DROP_MODULES)) + try: + sentry_sdk.init( + dsn=mlconfig.MLCHAIN_SENTRY_DSN, + integrations=[FlaskIntegration()], + 
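+            # sample_rate throttles error events, while traces_sample_rate
+            # throttles performance transactions; both values come from the
+            # sentry block in mlconfig.yaml (or the matching env vars).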
sample_rate=mlconfig.MLCHAIN_SENTRY_SAMPLE_RATE, + traces_sample_rate=mlconfig.MLCHAIN_SENTRY_TRACES_SAMPLE_RATE, + server_name=mlconfig.MLCHAIN_SERVER_NAME, + environment=mlconfig.MLCHAIN_DEFAULT_MODE, + before_send=before_send + ) + + sentry_sdk.set_context( + key = "app", + value = { + "app_start_time": datetime.datetime.now(), + "app_name": str(mlconfig.MLCHAIN_SERVER_NAME), + "app_version": str(mlconfig.MLCHAIN_SERVER_VERSION), + } + ) + logger.info("Initialized Sentry to {0} and traces_sample_rate: {1} and sample_rate: {2} and drop_modules: {3}".format(mlconfig.MLCHAIN_SENTRY_DSN, mlconfig.MLCHAIN_SENTRY_TRACES_SAMPLE_RATE, mlconfig.MLCHAIN_SENTRY_SAMPLE_RATE, mlconfig.MLCHAIN_SENTRY_DROP_MODULES)) + except sentry_sdk.utils.BadDsn: + if 'http' in mlconfig.MLCHAIN_SENTRY_DSN: + raise SystemExit("Sentry DSN configuration is invalid") def load_json(path): import json diff --git a/mlchain/context.py b/mlchain/context.py index a3bcab8..499fcf2 100644 --- a/mlchain/context.py +++ b/mlchain/context.py @@ -1,7 +1,6 @@ import contextvars from copy import deepcopy - class MLChainContext: variables = contextvars.ContextVar("mlchain_variables") @@ -63,7 +62,7 @@ def update(self, vars: dict): variables.update(vars) self.variables.set(variables) - def get(self): + def to_dict(self): try: variables = self.variables.get() if variables is None: @@ -84,5 +83,10 @@ def copy(self): def set(self, variables): self.variables.set(variables) + def __getattr__(self, item): + return self.__getitem__(item) + + def set_mlchain_context_id(self, value: str): + self.update({"MLCHAIN_CONTEXT_ID": value}) mlchain_context = MLChainContext() diff --git a/mlchain/server/flask_server.py b/mlchain/server/flask_server.py index 3895518..b6b5398 100644 --- a/mlchain/server/flask_server.py +++ b/mlchain/server/flask_server.py @@ -144,7 +144,6 @@ def __call__(self, *args, **kwargs): return response_function(output, 200) except MlChainError as ex: err = ex.msg - logger.error("code: {0} msg: {1}".format(ex.code, ex.msg)) output = { 'error': err, @@ -156,7 +155,6 @@ def __call__(self, *args, **kwargs): return response_function(output, ex.status_code) except AssertionError as ex: err = str(ex) - logger.error(err) output = { 'error': err, @@ -166,8 +164,7 @@ def __call__(self, *args, **kwargs): } return response_function(output, 422) except Exception: - err = str(format_exc(name='mlchain.serve.server')) - logger.error(err) + err = format_exc(name='mlchain.serve.server', return_str=False) output = { 'error': err, @@ -255,7 +252,7 @@ def __init__(self, model: ServeModel, name=None, version='0.0', self.app.add_url_rule('/call_raw/', 'call_raw', FlaskView(self, RawFormat(), self.authentication), methods=['POST', 'GET'], strict_slashes=False) - + def _get_file_name(self, storage): return storage.filename @@ -375,6 +372,7 @@ def run(self, host='127.0.0.1', port=8080, bind=None, cors=False, cors_resources if cors: CORS(self.app, resources=cors_resources, origins=cors_allow_origins) + if not gunicorn: if bind is not None: if isinstance(bind, str): @@ -410,7 +408,8 @@ def run(self, host='127.0.0.1', port=8080, bind=None, cors=False, cors_resources logger.info("-" * 80) loglevel = kwargs.get('loglevel', 'warning' if debug else 'info') + GunicornWrapper(self.app, bind=bind, workers=workers, timeout=timeout, keepalive=keepalive, max_requests=max_requests, loglevel=loglevel, worker_class=worker_class, - threads=threads, umask=umask, **kwargs).run() + threads=threads, umask=umask, **kwargs).run() \ No newline at end of file diff --git 
a/mlchain/server/format.py b/mlchain/server/format.py index b5c0e50..ded6943 100644 --- a/mlchain/server/format.py +++ b/mlchain/server/format.py @@ -1,10 +1,38 @@ from typing import List, Tuple, Dict import traceback from mlchain import logger, __version__ +from mlchain.base.log import sentry_ignore_logger from mlchain.base.serializer import JsonSerializer, MsgpackSerializer, MsgpackBloscSerializer from mlchain.base.exceptions import MLChainSerializationError, MlChainError from .base import RawResponse, JsonResponse, MLChainResponse - +from sentry_sdk import add_breadcrumb, capture_exception +import re + +def logging_error(exception, true_exception=None): + string_exception = "\n".join(exception) + sentry_ignore_logger.error(string_exception) + + # Log to sentry + add_breadcrumb( + category="500", + message="\n".join([x for x in exception if re.search(r"(site-packages\/mlchain\/)|(\/envs\/)|(\/anaconda)", x) is None]), + level='error', + ) + + try: + the_exception_1 = exception[-2] + except IndexError: + the_exception_1 = "" + + try: + the_exception_2 = exception[-1] + except IndexError: + the_exception_2 = "" + + if true_exception is not None: + capture_exception(true_exception) + else: + capture_exception(RuntimeError("{0} {1}".format(the_exception_1, the_exception_2))) class BaseFormat: def check(self, headers, form, files, data) -> bool: @@ -40,27 +69,33 @@ def make_response(self, function_name, headers, return JsonResponse(output, 200) else: if isinstance(exception, MlChainError): + error = exception.msg output = { 'error': exception.msg, 'code': exception.code, 'api_version': request_context.get('api_version'), 'mlchain_version': __version__ } + logging_error([error], true_exception=exception) return JsonResponse(output, exception.status_code) - if isinstance(exception, Exception): - error = ''.join(traceback.extract_tb(exception.__traceback__).format()).strip() + elif isinstance(exception, Exception): + error = traceback.extract_tb(exception.__traceback__).format() output = { 'error': error, 'api_version': request_context.get('api_version'), 'mlchain_version': __version__ } + logging_error(error, true_exception=exception) + return JsonResponse(output, 500) + else: + exception = exception.split("\n") + output = { + 'error': exception, + 'api_version': request_context.get('api_version'), + 'mlchain_version': __version__ + } + logging_error(exception) return JsonResponse(output, 500) - output = { - 'error': str(exception), - 'api_version': request_context.get('api_version'), - 'mlchain_version': __version__ - } - return JsonResponse(output, 500) class MLchainFormat(BaseFormat): @@ -112,29 +147,36 @@ def make_response(self, function_name, headers, output, status = 200 else: if isinstance(exception, MlChainError): + error = exception.msg output = { 'error': exception.msg, 'code': exception.code, 'api_version': request_context.get('api_version'), 'mlchain_version': __version__ } - status = exception.status_code - + logging_error([error], true_exception=exception) + return JsonResponse(output, exception.status_code) elif isinstance(exception, Exception): - error = ''.join(traceback.extract_tb(exception.__traceback__).format()).strip() + error = traceback.extract_tb(exception.__traceback__).format() output = { 'error': error, 'api_version': request_context.get('api_version'), 'mlchain_version': __version__ } - status = 500 + + logging_error(error, true_exception=exception) + return JsonResponse(output, 500) else: + exception = exception.split("\n") + logging_error(exception) output = { - 
'error': str(exception), + 'error': exception, 'api_version': request_context.get('api_version'), 'mlchain_version': __version__ } status = 500 + return JsonResponse(output, 500) + serializer_type = headers.get('mlchain-serializer', 'json') serializer = self.serializers.get(serializer_type, None) diff --git a/mlchain/server/grpc_server.py b/mlchain/server/grpc_server.py index 4faf3a8..96efefe 100644 --- a/mlchain/server/grpc_server.py +++ b/mlchain/server/grpc_server.py @@ -34,9 +34,9 @@ def call(self, request, context): kwargs = request.kwargs serializer = self.get_serializer(header.serializer) headers = request.headers - uid = uuid4().hex + uid = str(uuid4()) mlchain_context.set(headers) - mlchain_context['context_id'] = uid + mlchain_context['MLCHAIN_CONTEXT_ID'] = uid args = serializer.decode(args) kwargs = serializer.decode(kwargs) func = self.model.get_function(function_name) diff --git a/mlchain/server/quart_server.py b/mlchain/server/quart_server.py index 3a90616..1e0b552 100644 --- a/mlchain/server/quart_server.py +++ b/mlchain/server/quart_server.py @@ -162,7 +162,6 @@ async def __call__(self, *args, **kwargs): return await response_function(output, 200) except MlChainError as ex: err = ex.msg - logger.error("code: {0} msg: {1}".format(ex.code, ex.msg)) output = { 'error': err, @@ -174,7 +174,6 @@ async def __call__(self, *args, **kwargs): return await response_function(output, ex.status_code) except AssertionError as ex: err = str(ex) - logger.error(err) output = { 'error': err, @@ -184,8 +183,7 @@ async def __call__(self, *args, **kwargs): } return await response_function(output, 422) except Exception as ex: - err = str(format_exc(name='mlchain.serve.server')) - logger.error(err) + err = format_exc(name='mlchain.serve.server') output = { 'error': err, diff --git a/mlchain/server/view.py b/mlchain/server/view.py index f52e508..6bd3adf 100644 --- a/mlchain/server/view.py +++ b/mlchain/server/view.py @@ -9,7 +9,8 @@ from quart import Response as QuartReponse from .authentication import Authentication import traceback - +from sentry_sdk import push_scope, start_transaction +from mlchain import mlconfig class View: def __init__(self, server, formatter: BaseFormat = None, @@ -37,14 +38,29 @@ def get_format(self, headers, form, files, data): break return formatter - def init_context(self, headers): - context = {key[len('mlchain_context_'):]: value - for (key, value) in headers.items() if key.startswith('mlchain_context_')} - uid = uuid4().hex + def init_context_with_headers(self, headers, context_id: str = None): + context = {key: value + for (key, value) in headers.items()} + new_context = {} + for key, value in context.items(): + if key.lower().startswith("mlchain-context"): + new_context[key.upper().replace("-", "_")] = value + context.update(new_context) + mlchain_context.set(context) - mlchain_context['context_id'] = uid - return uid + if mlchain_context.MLCHAIN_CONTEXT_ID is None: + mlchain_context['MLCHAIN_CONTEXT_ID'] = context_id + else: + context_id = mlchain_context['MLCHAIN_CONTEXT_ID'] + + return context_id + + def init_context(self): + uid = str(uuid4()) + mlchain_context['MLCHAIN_CONTEXT_ID'] = uid + return uid + def normalize_output(self, formatter, function_name, headers, output, exception, request_context): if isinstance(output, FileResponse): @@ -66,44 +82,58 @@ def normalize_output(self, formatter, function_name, headers, return output def __call__(self, function_name, **kws): - request_context = { - 
'api_version': self.server.version - } - try: - headers, form, files, data = self.parse_data() - except Exception as ex: - request_context['time_process'] = 0 - output = self.normalize_output(self.base_format, function_name, {}, - None, ex, request_context) - return self.make_response(output) - - formatter = self.get_format(headers, form, files, data) - start_time = time.time() - try: - if self.authentication is not None: - self.authentication.check(headers) - args, kwargs = formatter.parse_request(function_name, headers, form, - files, data, request_context) - func = self.server.model.get_function(function_name) - kwargs = self.server.get_kwargs(func, *args, **kwargs) - kwargs = self.server._normalize_kwargs_to_valid_format(kwargs, func) - - uid = self.init_context(headers) - output = self.server.model.call_function(function_name, uid, **kwargs) - exception = None - except MlChainError as ex: - exception = ex - output = None - except Exception: - exception = traceback.format_exc() - logger.error(exception) - output = None - - time_process = time.time() - start_time - request_context['time_process'] = time_process - output = self.normalize_output(formatter, function_name, headers, - output, exception, request_context) - return self.make_response(output) + with push_scope() as scope: + transaction_name = "{0} || {1}".format(mlconfig.MLCHAIN_SERVER_NAME, function_name) + scope.transaction = transaction_name + + with start_transaction(op="task", name=transaction_name): + uid = self.init_context() + + request_context = { + 'api_version': self.server.version + } + try: + headers, form, files, data = self.parse_data() + mlchain_context['REQUESTS_HEADERS'] = headers + mlchain_context['REQUESTS_FORM'] = form + mlchain_context['REQUESTS_FILES'] = files + mlchain_context['REQUESTS_DATA'] = data + + except Exception as ex: + request_context['time_process'] = 0 + output = self.normalize_output(self.base_format, function_name, {}, + None, ex, request_context) + return self.make_response(output) + + formatter = self.get_format(headers, form, files, data) + start_time = time.time() + try: + if self.authentication is not None: + self.authentication.check(headers) + args, kwargs = formatter.parse_request(function_name, headers, form, + files, data, request_context) + func = self.server.model.get_function(function_name) + kwargs = self.server.get_kwargs(func, *args, **kwargs) + kwargs = self.server._normalize_kwargs_to_valid_format(kwargs, func) + + uid = self.init_context_with_headers(headers, uid) + scope.set_tag("transaction_id", uid) + logger.debug("Mlchain transaction id: {0}".format(uid)) + + output = self.server.model.call_function(function_name, uid, **kwargs) + exception = None + except MlChainError as ex: + exception = ex + output = None + except Exception as ex: + exception = ex + output = None + + time_process = time.time() - start_time + request_context['time_process'] = time_process + output = self.normalize_output(formatter, function_name, headers, + output, exception, request_context) + return self.make_response(output) class ViewAsync(View): @@ -117,38 +147,48 @@ async def make_response(self, response: Union[RawResponse, FileResponse]): return super().make_response(response) async def __call__(self, function_name, **kws): - request_context = { - 'api_version': self.server.version - } - try: - headers, form, files, data = await self.parse_data() - except Exception as ex: - request_context['time_process'] = 0 - output = self.normalize_output(self.base_format, function_name, {}, - None, ex, 
request_context) - return await self.make_response(output) - formatter = self.get_format(headers, form, files, data) - start_time = time.time() - try: - if self.authentication is not None: - self.authentication.check(headers) - args, kwargs = formatter.parse_request(function_name, headers, form, - files, data, request_context) - func = self.server.model.get_function(function_name) - kwargs = self.server.get_kwargs(func, *args, **kwargs) - kwargs = self.server._normalize_kwargs_to_valid_format(kwargs, func) - uid = self.init_context(headers) - output = await self.server.model.call_async_function(function_name, uid, **kwargs) - exception = None - except MlChainError as ex: - exception = ex - output = None - except Exception: - exception = traceback.format_exc() - logger.error(exception) - output = None - time_process = time.time() - start_time - request_context['time_process'] = time_process - output = self.normalize_output(formatter, function_name, headers, - output, exception, request_context) - return await self.make_response(output) + with push_scope() as scope: + transaction_name = "{0} || {1}".format(mlconfig.MLCHAIN_SERVER_NAME, function_name) + scope.transaction = transaction_name + + with start_transaction(op="task", name=transaction_name): + uid = self.init_context() + + request_context = { + 'api_version': self.server.version + } + try: + headers, form, files, data = await self.parse_data() + except Exception as ex: + request_context['time_process'] = 0 + output = self.normalize_output(self.base_format, function_name, {}, + None, ex, request_context) + return await self.make_response(output) + formatter = self.get_format(headers, form, files, data) + start_time = time.time() + try: + if self.authentication is not None: + self.authentication.check(headers) + args, kwargs = formatter.parse_request(function_name, headers, form, + files, data, request_context) + func = self.server.model.get_function(function_name) + kwargs = self.server.get_kwargs(func, *args, **kwargs) + kwargs = self.server._normalize_kwargs_to_valid_format(kwargs, func) + uid = self.init_context_with_headers(headers, uid) + scope.set_tag("transaction_id", uid) + logger.debug("Mlchain transaction id: {0}".format(uid)) + + output = await self.server.model.call_async_function(function_name, uid, **kwargs) + exception = None + except MlChainError as ex: + exception = ex + output = None + except Exception as ex: + exception = ex + output = None + + time_process = time.time() - start_time + request_context['time_process'] = time_process + output = self.normalize_output(formatter, function_name, headers, + output, exception, request_context) + return await self.make_response(output) diff --git a/mlchain/storage/__init__.py b/mlchain/storage/__init__.py index 3e7ed44..fc851ca 100644 --- a/mlchain/storage/__init__.py +++ b/mlchain/storage/__init__.py @@ -1,4 +1,2 @@ from mlchain import logger -from .base import MLStorage, Path -logger.warn("mlchain.storage is deprecated and will be remove in the next version. 
" - "Please use mlchain_extension.storage instead") +from .base import MLStorage, Path \ No newline at end of file diff --git a/mlchain/utils/__init__.py b/mlchain/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlchain/utils/system_info.py b/mlchain/utils/system_info.py new file mode 100644 index 0000000..be433d4 --- /dev/null +++ b/mlchain/utils/system_info.py @@ -0,0 +1,51 @@ +""" +The code is referenced from https://github.com/jacenkow/gpu-sentry/blob/master/gpu_sentry/client.py +""" +from pynvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, +) +from mlchain.base.log import logger + +def _convert_kb_to_gb(size): + """Convert given size in kB to GB with 2-decimal places rounding.""" + return round(size / 1024 ** 3, 2) + +class GPUStats: + def __init__(self): + try: + nvmlInit() + self.has_gpu = True + except Exception as error: + logger.debug(f"Cannot get GPU info: {error}") + self.has_gpu = False + if self.has_gpu: + self.gpu_count = nvmlDeviceGetCount() + + def get_gpu_statistics(self): + """Get statistics for each GPU installed in the system.""" + if not self.has_gpu: + return [] + statistics = [] + for i in range(self.gpu_count): + handle = nvmlDeviceGetHandleByIndex(i) + memory = nvmlDeviceGetMemoryInfo(handle) + statistics.append({ + "gpu": i, + "name": nvmlDeviceGetName(handle).decode("utf-8"), + "memory": { + "total": _convert_kb_to_gb(int(memory.total)), + "used": _convert_kb_to_gb(int(memory.used)), + "utilisation": int(memory.used / memory.total * 100) + }, + }) + return statistics + +gpu_stats = GPUStats() + +def get_gpu_statistics(): + return gpu_stats.get_gpu_statistics() \ No newline at end of file diff --git a/mlchain/workflows/parallel.py b/mlchain/workflows/parallel.py index 277e34a..0e6b7e0 100644 --- a/mlchain/workflows/parallel.py +++ b/mlchain/workflows/parallel.py @@ -6,7 +6,6 @@ class TrioProgress(trio.abc.Instrument): - def __init__(self, total, notebook_mode=False, **kwargs): if notebook_mode: # pragma: no cover from tqdm.notebook import tqdm @@ -29,8 +28,15 @@ class Parallel: :verbose: Print error or not """ - def __init__(self, tasks: [], max_threads: int = 10, max_retries: int = 0, - pass_fail_job: bool = False, verbose: bool = True, threading: bool = True): + def __init__( + self, + tasks: [], + max_threads: int = 10, + max_retries: int = 0, + pass_fail_job: bool = False, + verbose: bool = True, + threading: bool = True, + ): """ :tasks: [Task, function] items :max_threads: Maximum threads to Parallel, max_threads=0 means no limitation @@ -40,8 +46,8 @@ def __init__(self, tasks: [], max_threads: int = 10, max_retries: int = 0, """ assert isinstance(tasks, list) and all( - callable(task) for task in tasks), \ - 'You have to transfer a list of callable instances or mlchain.Task' + callable(task) for task in tasks + ), "You have to transfer a list of callable instances or mlchain.Task" self.tasks = tasks if max_threads == -1: max_threads = 100 @@ -64,7 +70,9 @@ def update_progress_bar(self): if self.show_progress_bar: self.progress_bar.task_processed() - async def __call_sync(self, task, outputs, idx, limiter, max_retries=1, pass_fail_job=False): + async def __call_sync( + self, task, outputs, idx, limiter, max_retries=1, pass_fail_job=False + ): if limiter is not None: async with limiter: for retry_idx in range(max_retries): @@ -76,14 +84,17 @@ async def __call_sync(self, task, outputs, idx, limiter, max_retries=1, pass_fai if retry_idx == 
max_retries - 1 and not pass_fail_job: with except_handler(): - raise AssertionError("ERROR in {}th task\n".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + raise AssertionError( + "ERROR in {0}th task\n{1}".format(idx, format_exc(name="mlchain.workflows.parallel")) + ) if retry_idx < max_retries - 1 or not self.verbose: - logger.error("PARALLEL ERROR in {0}th task and retry task," - " run times = {1}".format(idx, retry_idx + 1)) + logger.error( + "PARALLEL ERROR in {0}th task and retry task," + " run times = {1}".format(idx, retry_idx + 1) + ) else: - logger.debug("PASSED PARALLEL ERROR: " - + format_exc(name='mlchain.workflows.parallel')) + logger.debug( + "PASSED PARALLEL ERROR: {0}".format(format_exc(name="mlchain.workflows.parallel")) + ) self.update_progress_bar() - async def __call_async(self, task, outputs, idx, limiter, - max_retries=1, pass_fail_job=False): + async def __call_async( + self, task, outputs, idx, limiter, max_retries=1, pass_fail_job=False + ) : if limiter is not None: async with limiter: for retry_idx in range(max_retries): @@ -115,14 +131,24 @@ async def __call_async(self, task, outputs, idx, limiter, except Exception: if retry_idx == max_retries - 1 and not pass_fail_job: with except_handler(): - raise AssertionError("ERROR in {}th task\n".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + raise AssertionError( + "ERROR in {0}th task\n{1}".format( + idx, + format_exc(name="mlchain.workflows.parallel"), + ) + ) + if retry_idx < max_retries - 1 or not self.verbose: - logger.error("PARALLEL ERROR in {0}th task and retry task, " - "run times = {1}".format(idx, retry_idx + 1)) + logger.error( + "PARALLEL ERROR in {0}th task and retry task, " + "run times = {1}".format(idx, retry_idx + 1) + ) else: - logger.debug("PASSED PARALLEL ERROR in {}th task:".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + logger.debug( + "PASSED PARALLEL ERROR in {0}th task: {1}".format( + idx, format_exc(name="mlchain.workflows.parallel") + ) + ) else: for retry_idx in range(max_retries): try: @@ -132,14 +158,24 @@ async def __call_async(self, task, outputs, idx, limiter, except Exception as ex: if retry_idx == max_retries - 1 and not pass_fail_job: with except_handler(): - raise AssertionError("ERROR in {}th task\n".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + raise AssertionError( + "ERROR in {0}th task\n{1}".format( + idx, format_exc(name="mlchain.workflows.parallel") + ) + ) + if retry_idx < max_retries - 1 or not self.verbose: - logger.error("PARALLEL ERROR in {0}th task 
and retry task, " - "run times = {1}".format(idx, retry_idx + 1)) + logger.error( + "PARALLEL ERROR in {0}th task and retry task, " + "run times = {1}".format(idx, retry_idx + 1) + ) else: - logger.debug("PASSED PARALLEL ERROR in {}th task:".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + logger.debug( + "PASSED PARALLEL ERROR in {}th task: {1}".format( + idx, format_exc(name="mlchain.workflows.parallel") + ) + ) + self.update_progress_bar() async def dispatch(self): @@ -153,17 +189,40 @@ async def dispatch(self): async with trio.open_nursery() as nursery: for idx, task in enumerate(self.tasks): - if hasattr(task, 'to_async') and callable(task.to_async): - nursery.start_soon(self.__call_async, task.to_async(), outputs, idx, - self.limiter, self.max_retries, self.pass_fail_job) - elif inspect.iscoroutinefunction(task) \ - or (not inspect.isfunction(task) and hasattr(task, '__call__') - and inspect.iscoroutinefunction(task.__call__)): - nursery.start_soon(self.__call_async, task, outputs, idx, - self.limiter, self.max_retries, self.pass_fail_job) + if hasattr(task, "to_async") and callable(task.to_async): + nursery.start_soon( + self.__call_async, + task.to_async(), + outputs, + idx, + self.limiter, + self.max_retries, + self.pass_fail_job, + ) + elif inspect.iscoroutinefunction(task) or ( + not inspect.isfunction(task) + and hasattr(task, "__call__") + and inspect.iscoroutinefunction(task.__call__) + ): + nursery.start_soon( + self.__call_async, + task, + outputs, + idx, + self.limiter, + self.max_retries, + self.pass_fail_job, + ) else: - nursery.start_soon(self.__call_sync, task, outputs, idx, - self.limiter, self.max_retries, self.pass_fail_job) + nursery.start_soon( + self.__call_sync, + task, + outputs, + idx, + self.limiter, + self.max_retries, + self.pass_fail_job, + ) return outputs @@ -177,11 +236,14 @@ def exec_task(self, task, idx=None): if retry_idx == self.max_retries - 1 and not self.pass_fail_job: return ex if retry_idx < self.max_retries - 1 or not self.verbose: - logger.error("PARALLEL ERROR in {0}th task and retry task, " - "run times = {1}".format(idx, retry_idx + 1)) + logger.error( + "PARALLEL ERROR in {0}th task and retry task, " + "run times = {1}".format(idx, retry_idx + 1) + ) else: - logger.debug("PASSED PARALLEL ERROR in {}th task:".format(idx) - + format_exc(name='mlchain.workflows.parallel')) + logger.debug( + "PASSED PARALLEL ERROR in {}th task:".format(idx, format_exc(name="mlchain.workflows.parallel")) + ) return None def run(self, progress_bar: bool = False, notebook_mode: bool = False): @@ -194,10 +256,14 @@ def run(self, progress_bar: bool = False, notebook_mode: bool = False): pool = ThreadPool(max(1, self.max_threads)) if progress_bar: self.show_progress_bar = True - self.progress_bar = TrioProgress(total=len(self.tasks), - notebook_mode=notebook_mode) + self.progress_bar = TrioProgress( + total=len(self.tasks), notebook_mode=notebook_mode + ) - async_result = [pool.apply_async(self.exec_task, args=[task, idx]) for idx, task in enumerate(self.tasks)] + async_result = [ + pool.apply_async(self.exec_task, args=[task, idx]) + for idx, task in enumerate(self.tasks) + ] results = [] for result in async_result: @@ -211,6 +277,7 @@ def run(self, progress_bar: bool = False, notebook_mode: bool = False): return results if progress_bar: self.show_progress_bar = True - self.progress_bar = TrioProgress(total=len(self.tasks), - notebook_mode=notebook_mode) + self.progress_bar = TrioProgress( + total=len(self.tasks), notebook_mode=notebook_mode + ) 
return trio.run(self.dispatch) diff --git a/mlchain/workflows/task.py b/mlchain/workflows/task.py index 7a2a191..9f66bac 100644 --- a/mlchain/workflows/task.py +++ b/mlchain/workflows/task.py @@ -1,7 +1,7 @@ import inspect import trio from mlchain import mlchain_context - +from sentry_sdk import Hub class Task: """ @@ -33,14 +33,23 @@ async def __call__(self): """ Task's process code """ - if inspect.iscoroutinefunction(self.func_) \ - or (not inspect.isfunction(self.func_) - and hasattr(self.func_, '__call__') - and inspect.iscoroutinefunction(self.func_.__call__)): - async with self: - return await self.func_(*self.args, **self.kwargs) - with self: - return self.func_(*self.args, **self.kwargs) + async def _call_func(): + if inspect.iscoroutinefunction(self.func_) \ + or (not inspect.isfunction(self.func_) + and hasattr(self.func_, '__call__') + and inspect.iscoroutinefunction(self.func_.__call__)): + async with self: + return await self.func_(*self.args, **self.kwargs) + with self: + return self.func_(*self.args, **self.kwargs) + + transaction = Hub.current.scope.transaction + + if transaction is not None: + with transaction.start_child(op="task", description="{0}".format(self.func_.__name__)) as span: + return await _call_func() + else: + return await _call_func() async def __aenter__(self): return self.__enter__() diff --git a/requirements.txt b/requirements.txt index 9442db8..2d2dcf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,34 +1,38 @@ -attrs>=19.3.0 -blosc>=1.8.3; sys_platform != 'win32' +attrs>=20.3.0 +blosc>=1.10.1; sys_platform != 'win32' Click>=7.1.2 Flask>=1.1.2 -Flask-Cors>=3.0.8 +Flask-Cors>=3.0.9 gunicorn>=20.0.4 -h2==3.2.0 -Jinja2>=2.10.3 +h2==4.0.0 +Jinja2>=2.11.2 MarkupSafe>=1.1.1 -msgpack==1.0.0 -numpy<=1.18.1 +msgpack==1.0.2 +numpy<=1.19.4 opencv-python>=4.1.2.30 -Pillow==6.0.0 -Quart<=0.10.0 -Quart-CORS<=0.2.0 -requests>=2.22.0 +Pillow>=8.0.1 +Quart<=0.6.15; python_version < '3.7' +Quart<=0.14.1; python_version > '3.6' +Quart-CORS<=0.3.0 +requests>=2.25.1 six>=1.13.0 toml>=0.10.0 -trio>=0.13.0 -urllib3>=1.25.7 -uvicorn<=0.11.5 +trio>=0.17.0 +urllib3>=1.26.2 +uvicorn<=0.13.2 uvloop>=0.14.0; sys_platform != 'win32' -Werkzeug>=0.15.0 -httpx==0.13.3 -hypercorn>=0.5.4 -grpcio==1.27.2 +Werkzeug>=1.0.1 +httpx==0.16.1 +hypercorn<=0.5.4; python_version < '3.7' +hypercorn>=0.11.1; python_version > '3.6' +grpcio protobuf>=3.10.0 -boto3>=1.9.66 -pyyaml>=5.1 +boto3>=1.16.43 +pyyaml>=5.3.1 +sentry-sdk[flask]>=0.19.5 fuzzywuzzy GPUtil tqdm pyngrok -python-Levenshtein \ No newline at end of file +python-Levenshtein +pynvml \ No newline at end of file diff --git a/setup.py b/setup.py index cb414f2..bdde842 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import os from setuptools import setup, find_packages -__version__ = "0.1.8rc1" +__version__ = "0.1.9" project = "mlchain" @@ -25,7 +25,7 @@ def parse_requirements(filename): url='http://github.com/Techainer/mlchain-python', author='Techainer Inc.', author_email='admin@techainer.com', - package_data={'mlchain.cli': ['config.yaml'],'mlchain.server':['static/*','templates/*','templates/swaggerui/*']}, + package_data={'mlchain.cli': ['mlconfig.yaml', 'mlchain_server.py'],'mlchain.server':['static/*','templates/*','templates/swaggerui/*']}, include_package_data=True, packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), install_requires=install_requires, diff --git a/tests/dummy_server/server.py b/tests/dummy_server/server.py index a32a17c..f310a3d 100644 --- 
a/tests/dummy_server/server.py +++ b/tests/dummy_server/server.py @@ -10,7 +10,7 @@ def __init__(self): pass - def predict(self, image: np.ndarray): + def predict(self, image: np.ndarray = None): """ Resize input to 100 by 100. Args: @@ -18,6 +18,8 @@ def predict(self, image: np.ndarray): Returns: The image (np.ndarray) at 100 by 100. """ + if image is None: + return 'Hihi' image = cv2.resize(image, (100, 100)) return image diff --git a/tests/test_limiter.py b/tests/test_limiter.py index 4654e67..acb1883 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -30,5 +30,26 @@ def test_limiter_2(self): total_time = time.time() - start_time assert total_time >= 3 + def test_limiter_fail(self): + with self.assertRaises(ValueError): + RateLimiter(max_calls=1, period=0) + + with self.assertRaises(ValueError): + RateLimiter(max_calls=0, period=1) + + def test_limiter_with_callback(self): + start_time = time.time() + global abc + abc = 0 + def callback(i): + global abc + abc += 1 + limiter = RateLimiter(max_calls=3, period=1, callback=callback) + for i in range(10): + with limiter: + pass + total_time = time.time() - start_time + assert total_time >= 3 + if __name__ == '__main__': unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..95d2517 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,16 @@ +import logging +import unittest + +from mlchain.base.utils import * +logger = logging.getLogger() + +class TestUtils(unittest.TestCase): + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + logger.info("Running utils test") + + def test_nothing(self): + pass + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_workflow.py b/tests/test_workflow.py index bf69d2a..5d71991 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -104,6 +104,14 @@ def dummy_task(): logger.info(x) background.stop() + try: + background = Background(task, interval=0.01).run(pass_fail_job=False) + time.sleep(0.02) + logger.info(x) + background.stop() + except Exception: + pass + def test_mlchain_async_task(self): async def dummy_task(n): return n+1
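
A short usage sketch of the new client-side error handling introduced in this patch. The `Client` entry point, the address, and the `predict` endpoint are assumptions based on the test server above; adjust them to your deployment:

    from mlchain.client import Client  # assumed client factory
    from mlchain.base.exceptions import MLChainConnectionError, MLChainTimeoutError

    model = Client(api_address='http://localhost:8001', serializer='msgpack').model()
    try:
        result = model.predict()  # the dummy server returns 'Hihi' when image is None
    except MLChainConnectionError as err:
        # raised when httpx cannot connect to the server (code: 'connection_error')
        print(err.code, err.msg)
    except MLChainTimeoutError as err:
        # raised when the call hits a timeout or an httpx read/write timeout (code: 'timeout')
        print(err.code, err.msg)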