From 1a5673b9d561aa238492dfd897a6a3df358aa925 Mon Sep 17 00:00:00 2001 From: michelia Date: Tue, 17 Dec 2024 20:36:57 +0800 Subject: [PATCH] refactor: speed up model estimate --- vox_box/estimator/estimate.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vox_box/estimator/estimate.py b/vox_box/estimator/estimate.py index 537c03e..f7428a2 100644 --- a/vox_box/estimator/estimate.py +++ b/vox_box/estimator/estimate.py @@ -6,6 +6,7 @@ from vox_box.estimator.faster_whisper import FasterWhisper from vox_box.estimator.funasr import FunASR from vox_box.utils.model import create_model_dict +from concurrent.futures import ThreadPoolExecutor, as_completed def estimate_model(cfg: Config) -> Dict: @@ -25,7 +26,14 @@ def estimate_model(cfg: Config) -> Dict: def get_model_info(estimator: Estimator) -> Dict: return estimator.model_info() - for e in estimators: - model_info = e.model_info() - if model_info["supported"]: - return model_info + with ThreadPoolExecutor() as executor: + futures = {executor.submit(get_model_info, e): e for e in estimators} + for future in as_completed(futures): + result = future.result() + if result["supported"]: + for f in futures: + if not f.done(): + f.cancel() + return result + + return model_info