Skip to content

Commit

Permalink
✅ test(translator): add unit tests for translation caching
Browse files Browse the repository at this point in the history
- implement unit tests for translation caching functionalities
- ensure cache interacts correctly with different input cases
- verify cache behavior with overridden translations and parameters
  • Loading branch information
awwaawwa committed Dec 23, 2024
1 parent ddf5426 commit 4cbacd2
Showing 1 changed file with 55 additions and 200 deletions.
255 changes: 55 additions & 200 deletions test/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,213 +1,68 @@
import unittest
from pdf2zh.translator import BaseTranslator
from pdf2zh import cache
import threading
import multiprocessing
import random
import string


class TestCache(unittest.TestCase):
class AutoIncreaseTranslator(BaseTranslator):
name = "auto_increase"
n = 0

def do_translate(self, text):
self.n += 1
return str(self.n)


class TestTranslator(unittest.TestCase):
def setUp(self):
self.test_db = cache.init_test_db()

def tearDown(self):
# Clean up
cache.clean_test_db(self.test_db)

def test_basic_set_get(self):
"""Test basic set and get operations"""
cache_instance = cache.TranslationCache("test_engine")

# Test get with non-existent entry
result = cache_instance.get("hello")
self.assertIsNone(result)

# Test set and get
cache_instance.set("hello", "你好")
result = cache_instance.get("hello")
self.assertEqual(result, "你好")

def test_cache_overwrite(self):
"""Test that cache entries can be overwritten"""
cache_instance = cache.TranslationCache("test_engine")

# Set initial translation
cache_instance.set("hello", "你好")

# Overwrite with new translation
cache_instance.set("hello", "您好")

# Verify the new translation is returned
result = cache_instance.get("hello")
self.assertEqual(result, "您好")

def test_non_string_params(self):
"""Test that non-string parameters are automatically converted to JSON"""
params = {"model": "gpt-3.5", "temperature": 0.7}
cache_instance = cache.TranslationCache("test_engine", params)

# Test that params are converted to JSON string internally
cache_instance.set("hello", "你好")
result = cache_instance.get("hello")
self.assertEqual(result, "你好")

# Test with different param types
array_params = ["param1", "param2"]
cache_instance2 = cache.TranslationCache("test_engine", array_params)
cache_instance2.set("hello", "你好2")
self.assertEqual(cache_instance2.get("hello"), "你好2")

# Test with nested structures
nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
cache_instance3 = cache.TranslationCache("test_engine", nested_params)
cache_instance3.set("hello", "你好3")
self.assertEqual(cache_instance3.get("hello"), "你好3")

def test_engine_distinction(self):
"""Test that cache distinguishes between different translation engines"""
cache1 = cache.TranslationCache("engine1")
cache2 = cache.TranslationCache("engine2")

# Set same text with different engines
cache1.set("hello", "你好 1")
cache2.set("hello", "你好 2")

# Verify each engine gets its own translation
self.assertEqual(cache1.get("hello"), "你好 1")
self.assertEqual(cache2.get("hello"), "你好 2")

def test_params_distinction(self):
"""Test that cache distinguishes between different engine parameters"""
params1 = {"param": "value1"}
params2 = {"param": "value2"}
cache1 = cache.TranslationCache("test_engine", params1)
cache2 = cache.TranslationCache("test_engine", params2)

# Set same text with different parameters
cache1.set("hello", "你好 1")
cache2.set("hello", "你好 2")

# Verify each parameter set gets its own translation
self.assertEqual(cache1.get("hello"), "你好 1")
self.assertEqual(cache2.get("hello"), "你好 2")

def test_consistent_param_serialization(self):
"""Test that dictionary parameters are consistently serialized regardless of key order"""
# Test simple dictionary
params1 = {"b": 1, "a": 2}
params2 = {"a": 2, "b": 1}
cache1 = cache.TranslationCache("test_engine", params1)
cache2 = cache.TranslationCache("test_engine", params2)
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)

# Test nested dictionary
params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
cache1 = cache.TranslationCache("test_engine", params1)
cache2 = cache.TranslationCache("test_engine", params2)
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)

# Test dictionary with list of dictionaries
params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
cache1 = cache.TranslationCache("test_engine", params1)
cache2 = cache.TranslationCache("test_engine", params2)
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)

# Test that different values still produce different results
params1 = {"a": 1, "b": 2}
params2 = {"a": 2, "b": 1}
cache1 = cache.TranslationCache("test_engine", params1)
cache2 = cache.TranslationCache("test_engine", params2)
self.assertNotEqual(
cache1.translate_engine_params, cache2.translate_engine_params
)

def test_cache_with_sorted_params(self):
"""Test that cache works correctly with sorted parameters"""
params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}

# Both caches should work with the same key
cache1 = cache.TranslationCache("test_engine", params1)
cache1.set("hello", "你好")

cache2 = cache.TranslationCache("test_engine", params2)
self.assertEqual(cache2.get("hello"), "你好")

def test_append_params(self):
"""Test the append_params method"""
cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})

# Test appending new parameter
cache_instance.add_params("new_param", "new_value")
self.assertEqual(
cache_instance.params, {"initial": "value", "new_param": "new_value"}
)

# Test that cache with appended params works correctly
cache_instance.set("hello", "你好")
self.assertEqual(cache_instance.get("hello"), "你好")

# Test overwriting existing parameter
cache_instance.add_params("initial", "new_value")
self.assertEqual(
cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
)

# Cache should work with updated params
cache_instance.set("hello2", "你好2")
self.assertEqual(cache_instance.get("hello2"), "你好2")

def test_thread_safety(self):
"""Test thread safety of cache operations"""
cache_instance = cache.TranslationCache("test_engine")
lock = threading.Lock()
results = []
num_threads = multiprocessing.cpu_count()
items_per_thread = 100

def generate_random_text(length=10):
return "".join(
random.choices(string.ascii_letters + string.digits, k=length)
)

def worker():
thread_results = [] # 线程本地存储结果
for _ in range(items_per_thread):
text = generate_random_text()
translation = f"翻译_{text}"

# Write operation
cache_instance.set(text, translation)

# Read operation - verify our own write
result = cache_instance.get(text)
thread_results.append((text, result))

# 所有操作完成后,一次性加锁并追加结果
with lock:
results.extend(thread_results)

# Create threads equal to CPU core count
threads = []
for _ in range(num_threads):
thread = threading.Thread(target=worker)
threads.append(thread)
thread.start()

# Wait for all threads to complete
for thread in threads:
thread.join()

# Verify all operations were successful
expected_total = num_threads * items_per_thread
self.assertEqual(len(results), expected_total)

# Verify each thread got its correct value
for text, result in results:
expected = f"翻译_{text}"
self.assertEqual(result, expected)
def test_cache(self):
translator = AutoIncreaseTranslator("en", "zh", "test")
# First translation should be cached
text = "Hello World"
first_result = translator.translate(text)

# Second translation should return the same result from cache
second_result = translator.translate(text)
self.assertEqual(first_result, second_result)

# Different input should give different result
different_text = "Different Text"
different_result = translator.translate(different_text)
self.assertNotEqual(first_result, different_result)

# Test cache with ignore_cache=True
translator.ignore_cache = True
no_cache_result = translator.translate(text)
self.assertNotEqual(first_result, no_cache_result)

def test_add_cache_impact_parameters(self):
translator = AutoIncreaseTranslator("en", "zh", "test")

# Test cache with added parameters
text = "Hello World"
first_result = translator.translate(text)
translator.add_cache_impact_parameters("test", "value")
second_result = translator.translate(text)
self.assertNotEqual(first_result, second_result)

# Test cache with ignore_cache=True
translator.ignore_cache = True
no_cache_result = translator.translate(text)
self.assertNotEqual(first_result, no_cache_result)

# Test cache with ignore_cache=False
translator.ignore_cache = False
cache_result = translator.translate(text)
self.assertEqual(second_result, cache_result)

# Test cache with another parameter
translator.add_cache_impact_parameters("test2", "value2")
another_result = translator.translate(text)
self.assertNotEqual(second_result, another_result)


if __name__ == "__main__":
Expand Down

0 comments on commit 4cbacd2

Please sign in to comment.