diff --git a/pdf2zh/backend.py b/pdf2zh/backend.py index 34365a4f..9fe63524 100644 --- a/pdf2zh/backend.py +++ b/pdf2zh/backend.py @@ -6,9 +6,7 @@ import tqdm import json import io -from pdf2zh.doclayout import OnnxModel - -model = OnnxModel.load_available() +from pdf2zh.pdf2zh import model flask_app = Flask("pdf2zh") flask_app.config.from_mapping( @@ -18,6 +16,7 @@ ) ) + def celery_init_app(app: Flask) -> Celery: class FlaskTask(Task): def __call__(self, *args, **kwargs): diff --git a/pdf2zh/gui.py b/pdf2zh/gui.py index 75ba5b67..1182d187 100644 --- a/pdf2zh/gui.py +++ b/pdf2zh/gui.py @@ -13,6 +13,7 @@ from pdf2zh import __version__ from pdf2zh.high_level import translate +from pdf2zh.pdf2zh import model from pdf2zh.translator import ( AnythingLLMTranslator, AzureOpenAITranslator, @@ -265,6 +266,7 @@ def progress_bar(t: tqdm.tqdm): "cancellation_event": cancellation_event_map[session_id], "envs": _envs, "prompt": prompt, + "model": model, } try: translate(**param) diff --git a/pdf2zh/pdf2zh.py b/pdf2zh/pdf2zh.py index c7b7810f..be682f2e 100644 --- a/pdf2zh/pdf2zh.py +++ b/pdf2zh/pdf2zh.py @@ -199,6 +199,9 @@ def find_all_files_in_directory(directory_path): return file_paths +model = None + + def main(args: Optional[List[str]] = None) -> int: logging.basicConfig() @@ -206,6 +209,11 @@ def main(args: Optional[List[str]] = None) -> int: if parsed_args.debug: log.setLevel(logging.DEBUG) + global model + if parsed_args.onnx: + model = OnnxModel(parsed_args.onnx) + else: + model = OnnxModel.load_available() if parsed_args.interactive: from pdf2zh.gui import setup_gui @@ -238,12 +246,6 @@ def main(args: Optional[List[str]] = None) -> int: except Exception: raise ValueError("prompt error.") - model = None - if parsed_args.onnx: - model = OnnxModel(parsed_args.onnx) - else: - model = OnnxModel.load_available() - if parsed_args.dir: untranlate_file = find_all_files_in_directory(parsed_args.files[0]) parsed_args.files = untranlate_file