-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Increase MAX_CONTENT_LENGTH to 1GB * Default starlette -> flask, add parallel sync to be compatible with gevent * Handle ast.literal_eval by using json.loads * Add Background Sync * Update requirements * Add itsdangerous and remove python-Levenshtein * Fixed Click version for Flask 1 * Remove Werkzeug due to Flask 1.1.4 error * Remove Jinja2 due to Flask 1.1.4 * Fixed Jinja2 and Werkzeug * Fixed h11 issue * Replace fuzzywuzzy by thefuzz and drop support for python 3.6 * Fixed MarkupSafe * Remove redundant MarkupSafe>=1.1.1 * Re-update sentry-sdk[flask] * Does not use opencv-python 4.5 because of failed coverage test Co-authored-by: Hoang Viet <[email protected]>
- Loading branch information
1 parent
33fe470
commit 5c00860
Showing
12 changed files
with
320 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .background import Background | ||
from .parallel import Parallel | ||
from .task import Task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import inspect | ||
import time | ||
from threading import Thread, Event | ||
from .task import Task | ||
from datetime import timedelta | ||
from concurrent.futures import ThreadPoolExecutor | ||
import logging | ||
import traceback | ||
|
||
class BackgroundTask(Thread):
    """
    Thread that runs ``task`` once, or repeatedly every ``interval``.

    :interval: timedelta (or number of seconds) between runs; None runs the task once
    :task: callable, or a Task whose ``func_``/``args``/``kwargs`` are unpacked
    :max_repeat: maximum number of runs when ``interval`` is set; negative means
        repeat until ``stop()`` is called
    :callback: optional zero-argument callable submitted after every run
    :max_thread: worker count of the pool that executes the task
    :pass_fail_job: log a failed run instead of raising
    """

    # Tag stored as the first item of ``output`` when the task raised.
    _ERROR_TAG = "MLCHAIN_BACKGROUND_ERROR"

    def __init__(self, interval, task, max_repeat, callback=None, max_thread: int = 1, pass_fail_job: bool = False):
        assert callable(task)

        Thread.__init__(self)
        self.stopped = Event()
        self.is_done = False
        self.interval = interval
        self.task = task
        self.max_repeat = max_repeat
        self.callback = callback
        self.output = None
        self.pool_limit = ThreadPoolExecutor(max_workers=max_thread)
        self.pass_fail_job = pass_fail_job

        if callback is not None:
            # One worker only, so callbacks run one at a time in submission order.
            self.pool_limit_callback = ThreadPoolExecutor(max_workers=1)

    def stop(self):
        """Ask the repeat loop to end and block until the thread has finished."""
        self.stopped.set()
        self.join()

    def get_output(self, task, *args, **kwargs):
        """
        Call ``task(*args, **kwargs)`` and store the result in ``self.output``.

        When the task raises, ``self.output`` becomes the tuple
        ``("MLCHAIN_BACKGROUND_ERROR", <formatted traceback>)``. The callback is
        then triggered, and the failure is logged or raised per ``pass_fail_job``.
        """
        try:
            self.output = task(*args, **kwargs)
        except Exception:
            self.output = (self._ERROR_TAG, traceback.format_exc())
        self.call_the_callback()

    def call_the_callback(self):
        """Submit the callback (if any), then report a failed run."""
        # Same ``is not None`` test as __init__, which is what creates the pool.
        if self.callback is not None:
            self.pool_limit_callback.submit(self.callback)
        self._report_failure()

    def _report_failure(self):
        """Log, or raise Exception, when ``self.output`` holds the error tuple."""
        output = self.output
        failed = (
            isinstance(output, tuple)
            and len(output) == 2
            and output[0] == self._ERROR_TAG
        )
        if not failed:
            return

        message = "BACKGROUND CALL ERROR: {0}".format(output[1])
        if self.pass_fail_job:
            logging.error(message)
        else:
            raise Exception(message)

    def _submit_task(self):
        """Queue one execution of the task on the worker pool."""
        task = self.task
        # The original tested ``isinstance(type(task), Task)``, which can never be
        # true for a class object; checking the instance is what was meant.
        if isinstance(task, Task):
            self.pool_limit.submit(self.get_output, task.func_, *task.args, **task.kwargs)
        else:
            self.pool_limit.submit(self.get_output, task)

    def run(self):
        """
        Thread body: schedule the task, wait for the pools, report the last failure.

        :raises Exception: when the last stored output is an error and
            ``pass_fail_job`` is False
        """
        try:
            if self.interval is None:
                self._submit_task()
            else:
                interval = self.interval
                if isinstance(interval, timedelta):
                    seconds = interval.total_seconds()
                else:
                    seconds = float(interval)

                count_repeat = 0
                # The repeat limit is tested first so no extra interval is slept
                # once the limit is reached; Event.wait returns True on stop().
                while (self.max_repeat < 0 or count_repeat < self.max_repeat) \
                        and not self.stopped.wait(seconds):
                    self._submit_task()
                    count_repeat += 1
        finally:
            # Always release the workers and unblock wait(), even on an
            # unexpected error, so callers polling is_done cannot hang forever.
            self.pool_limit.shutdown(wait=True)
            self.is_done = True
            if self.callback is not None:
                self.pool_limit_callback.shutdown(wait=True)

        self._report_failure()

    def wait(self, interval: float = 0.1):
        """
        Block until the thread has finished its work and return the last output.

        :interval: polling period in seconds
        """
        while not self.is_done:
            time.sleep(interval)
        return self.output
|
||
class Background:
    """
    Run a task in background using Threading.Event

    :task: [Task, function] item
    :interval: timedelta or float seconds; None runs the task a single time
    :max_repeat: number of runs (requires a positive interval), or -1 for no limit
    :callback: optional callable triggered after each run
    """

    def __init__(self, task, interval: float = None, max_repeat: int = -1, callback=None):
        assert callable(task), 'You have to transfer a callable instance or an mlchain.Task'

        # Normalise numbers to timedelta BEFORE validating. The original compared
        # the raw value with ``0``, which raises TypeError for a timedelta even
        # though timedelta intervals are documented as supported.
        if interval is not None and isinstance(interval, (int, float)):
            interval = timedelta(seconds=interval)

        assert max_repeat == -1 or (
            max_repeat > 0 and interval is not None and interval > timedelta(0)
        ), "interval need to be set when max_repeat > 0"
        assert callback is None or callable(callback), "callback need to be callable"

        self.task = task
        self.interval = interval
        self.max_repeat = max_repeat
        self.callback = callback

    def run(self, max_thread: int = 1, pass_fail_job: bool = False):
        """
        Start the task on a new BackgroundTask thread and return that thread.

        :max_thread: worker count handed to the BackgroundTask
        :pass_fail_job: log a failed run instead of raising
        """
        task = BackgroundTask(
            interval=self.interval,
            task=self.task,
            max_repeat=self.max_repeat,
            callback=self.callback,
            max_thread=max_thread,
            pass_fail_job=pass_fail_job,
        )
        task.start()

        return task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
from multiprocessing.pool import ThreadPool | ||
from mlchain.base.log import format_exc, except_handler, logger | ||
from typing import List | ||
|
||
class TrioProgress:
    """Small tqdm wrapper that counts finished tasks and closes the bar at ``total``."""

    def __init__(self, total, notebook_mode=False, **kwargs):
        """
        :total: number of tasks the bar should expect
        :notebook_mode: use the tqdm.notebook flavour of the bar
        :kwargs: forwarded unchanged to the tqdm constructor
        """
        if not notebook_mode:
            from tqdm import tqdm as bar_factory
        else:  # pragma: no cover
            from tqdm.notebook import tqdm as bar_factory

        self.tqdm = bar_factory(total=total, **kwargs)
        self.count = 0
        self.total = total

    def task_processed(self):
        """Advance the bar by one step; close it when the count reaches ``total``."""
        self.tqdm.update(1)
        self.count += 1
        all_done = self.count == self.total
        if all_done:
            self.tqdm.close()
|
||
class Parallel:
    """
    Build a collection of tasks to be executed in parallel

    :tasks: List of [Task, function] items
    :max_threads: Maximum Threads for this Parallel
    :max_retries: Maximum retry time when a task fail
    :pass_fail_job: Pass or Raise error when a task run fail
    :verbose: Print error or not
    """

    def __init__(
        self,
        tasks: List,
        max_threads: int = 10,
        max_retries: int = 0,
        pass_fail_job: bool = False,
        verbose: bool = True,
    ):
        """
        :tasks: [Task, function] items
        :max_threads: Maximum threads to Parallel; -1 is mapped to 100 threads and
            0 to the number of CPUs reported by the OS
        :max_retries: How many times to retry when a job fails
        :pass_fail_job: No exception when a job fails (its result becomes None)
        :verbose: Verbose or not
        """
        assert isinstance(tasks, list) and all(
            callable(task) for task in tasks
        ), "You have to transfer a list of callable instances or mlchain.Task"
        self.tasks = tasks
        if max_threads == -1:
            max_threads = 100
        elif max_threads == 0:
            # os.cpu_count() may return None when the count cannot be determined.
            max_threads = os.cpu_count() or 1
        self.max_threads = max(0, max_threads)

        # Stored as the total number of attempts (first try + retries).
        self.max_retries = max(max_retries + 1, 1)
        self.pass_fail_job = pass_fail_job
        self.verbose = verbose
        self.show_progress_bar = False
        self.progress_bar = None

    def update_progress_bar(self):
        """Advance the progress bar by one task when a bar is active."""
        if self.show_progress_bar:
            self.progress_bar.task_processed()

    def exec_task(self, task, idx=None):
        """
        Run one task with retries.

        :task: object exposing ``exec()`` or a plain zero-argument callable
        :idx: position of the task, used only in log messages
        :return: the task output; the caught exception instance when every attempt
            failed and ``pass_fail_job`` is False; None when every attempt failed
            and ``pass_fail_job`` is True
        """
        # The original always called ``task.exec()``, so the plain functions that
        # the constructor explicitly accepts failed with AttributeError.
        runner = getattr(task, "exec", task)
        last_attempt = self.max_retries - 1

        for retry_idx in range(self.max_retries):
            try:
                output = runner()
            except Exception as ex:
                if retry_idx == last_attempt and not self.pass_fail_job:
                    # Returned, not raised, so that run() re-raises it in the caller.
                    return ex
                # NOTE(review): ``not self.verbose`` selects the louder branch,
                # which looks inverted; kept as-is pending confirmation of intent.
                if retry_idx < last_attempt or not self.verbose:
                    logger.error(
                        "PARALLEL ERROR in {0}th task and retry task, "
                        "run times = {1}".format(idx, retry_idx + 1)
                    )
                else:
                    # The traceback used to be passed to a format string with no
                    # placeholder for it and was silently dropped.
                    logger.debug(
                        "PASSED PARALLEL ERROR in {}th task: {}".format(
                            idx, format_exc(name="mlchain.workflows.parallel")
                        )
                    )
            else:
                self.update_progress_bar()
                return output

        # Every attempt failed and failures are tolerated: still count the task so
        # the progress bar can reach its total and close.
        self.update_progress_bar()
        return None

    def run(self, progress_bar: bool = False, notebook_mode: bool = False):
        """
        When you run parallel in root, please use this function

        :progress_bar: Use tqdm to show the progress of calling Parallel
        :notebook_mode: Put it to true if run mlchain inside notebook
        :return: list of task outputs, in the same order as ``tasks``
        :raises Exception: the exception of the first failed task (in task order)
            when ``pass_fail_job`` is False
        """
        if progress_bar:
            self.show_progress_bar = True
            self.progress_bar = TrioProgress(
                total=len(self.tasks), notebook_mode=notebook_mode
            )

        # Leaving the ``with`` block terminates the pool on both the success and
        # the failure path, so worker threads are always released.
        with ThreadPool(max(1, self.max_threads)) as pool:
            async_result = [
                pool.apply_async(self.exec_task, args=[task, idx])
                for idx, task in enumerate(self.tasks)
            ]

            results = []
            for result in async_result:
                output = result.get()
                if isinstance(output, Exception):
                    raise output
                results.append(output)
            return results
Oops, something went wrong.