Skip to content

Commit

Permalink
Merge pull request #324 from awwaawwa/better_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Byaidu authored Dec 25, 2024
2 parents 80c4a1b + 97e9a90 commit a1eb6c8
Show file tree
Hide file tree
Showing 8 changed files with 497 additions and 236 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/python-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ name: Test and Build Python Package

on:
push:
branches:
- main
pull_request:

jobs:
Expand Down
227 changes: 137 additions & 90 deletions pdf2zh/cache.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,138 @@
import tempfile
import os
import time
import hashlib
import shutil

# Root directory that holds one sub-directory per cache entry. It lives under
# the system temporary directory and is created at import time if missing.
cache_dir = os.path.join(tempfile.gettempdir(), "cache")
os.makedirs(cache_dir, exist_ok=True)
# Name of the file inside each cache entry that stores its timestamp.
time_filename = "update_time"
# Maximum number of cache directories that remove_extra() leaves in place.
max_cache = 5


def deterministic_hash(obj):
    """Return a stable 20-character hexadecimal key for ``obj``.

    The object is converted with ``str()``, encoded with the default
    ``str.encode()`` codec and hashed with SHA-256; only the first 20
    hexadecimal characters of the digest are returned.
    """
    encoded = str(obj).encode()
    return hashlib.sha256(encoded).hexdigest()[:20]


def get_dirs():
    """Return the full paths of all sub-directories of ``cache_dir``.

    Entries are returned in ``os.listdir`` order; anything that is not a
    directory is left out.
    """
    candidates = (os.path.join(cache_dir, name) for name in os.listdir(cache_dir))
    return [path for path in candidates if os.path.isdir(path)]


def get_time(dir):
    """Read the timestamp stored in a cache directory.

    Args:
        dir: Path of the cache directory.

    Returns:
        The float parsed from the ``time_filename`` file inside ``dir``, or
        ``float("inf")`` when that file does not exist.

    Raises:
        ValueError: If the file exists but does not contain a valid float.
    """
    timefile = os.path.join(dir, time_filename)
    try:
        # Context manager so the handle is closed deterministically instead
        # of being left to the garbage collector.
        with open(timefile, encoding="utf-8") as f:
            return float(f.read())
    except FileNotFoundError:
        # NOTE: remove_extra() evicts the directory with the *smallest*
        # timestamp, so +inf makes a directory without a timestamp file the
        # last eviction candidate, not the first.
        return float("inf")


def write_time(dir):
    """Record the current time as the timestamp of a cache directory.

    Args:
        dir: Path of an existing cache directory. The file named
            ``time_filename`` inside it is created or overwritten with
            ``str(time.time())``.
    """
    timefile = os.path.join(dir, time_filename)
    # The with-block guarantees the timestamp is flushed and the handle
    # closed before returning.
    with open(timefile, "w", encoding="utf-8") as f:
        f.write(str(time.time()))


def argmin(iterable):
    """Return the index of the smallest item of ``iterable``.

    On ties the first occurrence wins. An empty iterable raises
    ``ValueError`` (propagated from ``min``).
    """
    position, _smallest = min(enumerate(iterable), key=lambda pair: pair[1])
    return position


def remove_extra():
    """Prune the cache so that at most ``max_cache`` directories remain.

    First, every cache directory whose timestamp cannot be read is deleted.
    Then, if more than ``max_cache`` directories are left, the ones with the
    smallest timestamps (as reported by ``get_time``) are deleted until the
    limit is met.
    """
    for path in get_dirs():
        if not os.path.isdir(path):
            # get_dirs() only returns directories, so this can only trigger
            # if the entry was replaced between listing and checking.
            os.remove(path)
            continue
        try:
            get_time(path)
        except Exception:
            # Unreadable or corrupt timestamp: drop the whole entry.
            # Catching Exception rather than BaseException lets
            # KeyboardInterrupt and SystemExit propagate.
            shutil.rmtree(path)

    remaining = get_dirs()
    excess = len(remaining) - max_cache
    if excess > 0:
        # Read each timestamp once and evict the oldest entries. sorted() is
        # stable, so ties are evicted in listing order.
        for path in sorted(remaining, key=get_time)[:excess]:
            shutil.rmtree(path)


def is_cached(hash_key):
    """Return True when a path named ``hash_key`` exists under ``cache_dir``."""
    return os.path.exists(os.path.join(cache_dir, hash_key))


def create_cache(hash_key):
    """Ensure the cache directory for ``hash_key`` exists and stamp it.

    The directory is created under ``cache_dir`` if needed, and its
    timestamp file is (re)written through ``write_time``.
    """
    entry_path = os.path.join(cache_dir, hash_key)
    os.makedirs(entry_path, exist_ok=True)
    write_time(entry_path)


def load_paragraph(hash_key, hash_key_paragraph):
    """Return a cached paragraph, or None when it is not cached.

    Args:
        hash_key: Name of the cache directory under ``cache_dir``.
        hash_key_paragraph: Name of the paragraph file inside that directory.

    Returns:
        The UTF-8 decoded file content, or ``None`` if the file is missing.
    """
    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
    try:
        # Open directly instead of checking existence first: this avoids a
        # race with concurrent eviction and closes the handle reliably.
        with open(filename, encoding="utf-8") as f:
            return f.read()
    except (FileNotFoundError, NotADirectoryError):
        # Both mean "no such cache entry", matching what the previous
        # os.path.exists() check reported as a miss.
        return None


def write_paragraph(hash_key, hash_key_paragraph, paragraph):
    """Store a paragraph in the cache.

    Args:
        hash_key: Name of an existing cache directory under ``cache_dir``.
        hash_key_paragraph: Name of the paragraph file to create or overwrite.
        paragraph: Content to store; it is written as ``str(paragraph)``
            without a trailing newline, UTF-8 encoded.
    """
    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
    # The with-block flushes and closes the file before returning, so a
    # subsequent load_paragraph() always sees the complete content.
    with open(filename, "w", encoding="utf-8") as f:
        print(paragraph, file=f, end="")
import json
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
from typing import Optional


# The database object is created without a file path (None) here; the path
# is supplied later by init_db() through db.init().
db = SqliteDatabase(None)


class _TranslationCache(Model):
    """Database row holding one cached translation.

    A row is identified by the combination of translation engine, serialized
    engine parameters and original text.
    """

    # Auto-incrementing primary key.
    id = AutoField()
    # Name of the translation engine that produced the translation.
    translate_engine = CharField(max_length=20)
    # Engine parameters as text (TranslationCache stores a JSON string here).
    translate_engine_params = TextField()
    # Text that was submitted for translation.
    original_text = TextField()
    # Translated text returned for original_text.
    translation = TextField()

    class Meta:
        # Bound to the module-level database object.
        database = db
        # Table-level constraint: the (engine, params, original_text) triple
        # is unique, and inserting a duplicate replaces the existing row.
        constraints = [
            SQL(
                """
UNIQUE (
translate_engine,
translate_engine_params,
original_text
)
ON CONFLICT REPLACE
"""
            )
        ]


class TranslationCache:
    """Look up and store translations for one engine/parameter combination.

    Entries are keyed by the engine name, the JSON serialization of the
    engine parameters (with dictionary keys sorted, so key order does not
    matter) and the original text.

    Attributes:
        translate_engine: Name of the translation engine.
        params: The instance's own normalized copy of the parameters.
        translate_engine_params: JSON string of ``params`` used as cache key.
    """

    @staticmethod
    def _sort_dict_recursively(obj):
        """Return a copy of ``obj`` with the keys of every nested dict sorted.

        Dicts and lists are rebuilt recursively; any other value (tuples
        included) is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                key: TranslationCache._sort_dict_recursively(obj[key])
                for key in sorted(obj)
            }
        if isinstance(obj, list):
            return [TranslationCache._sort_dict_recursively(item) for item in obj]
        return obj

    def __init__(
        self, translate_engine: str, translate_engine_params: Optional[dict] = None
    ):
        """Create a cache view for ``translate_engine`` with the given params."""
        self.translate_engine = translate_engine
        self.replace_params(translate_engine_params)

    # The program typically starts multi-threaded translation
    # only after cache parameters are fully configured,
    # so thread safety doesn't need to be considered here.
    def replace_params(self, params: Optional[dict] = None):
        """Replace all parameters; ``None`` means no parameters."""
        if params is None:
            params = {}
        # Normalizing rebuilds every nested dict/list, so the stored copy
        # never aliases the caller's object; later update_params() and
        # add_params() calls therefore cannot mutate the caller's dict.
        normalized = self._sort_dict_recursively(params)
        self.params = normalized
        self.translate_engine_params = json.dumps(normalized)

    def update_params(self, params: Optional[dict] = None):
        """Merge ``params`` into the current parameters and re-serialize."""
        if params is None:
            params = {}
        self.params.update(params)
        self.replace_params(self.params)

    def add_params(self, k: str, v):
        """Set the single parameter ``k`` to ``v`` and re-serialize."""
        self.params[k] = v
        self.replace_params(self.params)

    # Per the original author's note: peewee and the underlying sqlite are
    # thread-safe, so get and set operations don't take locks.
    def get(self, original_text: str) -> Optional[str]:
        """Return the cached translation of ``original_text``, or None."""
        result = _TranslationCache.get_or_none(
            translate_engine=self.translate_engine,
            translate_engine_params=self.translate_engine_params,
            original_text=original_text,
        )
        if result is None:
            return None
        return result.translation

    def set(self, original_text: str, translation: str):
        """Store ``translation`` for ``original_text``.

        The table's unique constraint uses ON CONFLICT REPLACE, so an
        existing row for the same key is replaced.
        """
        _TranslationCache.create(
            translate_engine=self.translate_engine,
            translate_engine_params=self.translate_engine_params,
            original_text=original_text,
            translation=translation,
        )


def init_db(remove_exists=False):
    """Point the module-level database at the on-disk cache file.

    The cache lives in ``~/.cache/pdf2zh/cache.v1.db``; the folder is created
    if missing and the cache table is created if it does not exist yet.

    Args:
        remove_exists: When true, delete any existing cache database first,
            so the cache starts empty.
    """
    cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
    os.makedirs(cache_folder, exist_ok=True)
    # The current version does not support database migration, so add the version number to the file name.
    cache_db_path = os.path.join(cache_folder, "cache.v1.db")
    if remove_exists:
        # Do not delete the file underneath a connection that is still open.
        if not db.is_closed():
            db.close()
        # In WAL mode SQLite keeps "-wal" and "-shm" sidecar files next to
        # the database. Remove them together with the main file (as
        # clean_test_db() does) so stale journal data is never paired with
        # the freshly created database.
        for suffix in ("", "-wal", "-shm"):
            try:
                os.remove(cache_db_path + suffix)
            except FileNotFoundError:
                pass
    db.init(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    db.create_tables([_TranslationCache], safe=True)


def init_test_db():
    """Create a throw-away SQLite database and bind the cache model to it.

    Returns:
        The connected ``SqliteDatabase``; pass it to ``clean_test_db`` to
        drop the table and delete the files again.
    """
    import tempfile

    # mkstemp() creates the file atomically, unlike the deprecated and
    # race-prone mktemp(), which only returns an unused name. SQLite treats
    # the resulting empty file as a new database.
    fd, cache_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = SqliteDatabase(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
    test_db.connect()
    test_db.create_tables([_TranslationCache], safe=True)
    return test_db


def clean_test_db(test_db):
    """Tear down a database created by ``init_test_db``.

    Drops the cache table, closes the connection and deletes the database
    file together with its "-wal" and "-shm" companion files, if present.
    """
    test_db.drop_tables([_TranslationCache])
    test_db.close()
    base_path = test_db.database
    for suffix in ("", "-wal", "-shm"):
        candidate = base_path + suffix
        if os.path.exists(candidate):
            os.remove(candidate)


# Initialize the default on-disk cache database as soon as this module is
# imported.
init_db()
11 changes: 1 addition & 10 deletions pdf2zh/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import unicodedata
from tenacity import retry, wait_fixed
from pdf2zh import cache
from pdf2zh.translator import (
AzureOpenAITranslator,
BaseTranslator,
Expand Down Expand Up @@ -328,21 +327,13 @@ def vflag(font: str, char: str): # 匹配公式(和角标)字体
############################################################
# B. 段落翻译
log.debug("\n==========[SSTACK]==========\n")
hash_key = cache.deterministic_hash("PDFMathTranslate")
cache.create_cache(hash_key)

@retry(wait=wait_fixed(1))
def worker(s: str): # 多线程翻译
if not s.strip() or re.match(r"^\{v\d+\}$", s): # 空白和公式不翻译
return s
try:
hash_key_paragraph = cache.deterministic_hash(
(s, str(self.translator))
)
new = cache.load_paragraph(hash_key, hash_key_paragraph) # 查询缓存
if new is None:
new = self.translator.translate(s)
cache.write_paragraph(hash_key, hash_key_paragraph, new)
new = self.translator.translate(s)
return new
except BaseException as e:
if log.isEnabledFor(logging.DEBUG):
Expand Down
Loading

0 comments on commit a1eb6c8

Please sign in to comment.