From 4cbacd2c5d962d1d5c2ac403d2edfea2fd7ccb30 Mon Sep 17 00:00:00 2001 From: awwaawwa <8493196+awwaawwa@users.noreply.github.com> Date: Mon, 23 Dec 2024 23:22:25 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20test(translator):=20add=20unit=20te?= =?UTF-8?q?sts=20for=20translation=20caching?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - implement unit tests for translation caching functionalities - ensure cache interacts correctly with different input cases - verify cache behavior with overridden translations and parameters --- test/test_translator.py | 255 +++++++++------------------------------- 1 file changed, 55 insertions(+), 200 deletions(-) diff --git a/test/test_translator.py b/test/test_translator.py index 7aed3637..6ce3753d 100644 --- a/test/test_translator.py +++ b/test/test_translator.py @@ -1,213 +1,68 @@ import unittest +from pdf2zh.translator import BaseTranslator from pdf2zh import cache -import threading -import multiprocessing -import random -import string -class TestCache(unittest.TestCase): +class AutoIncreaseTranslator(BaseTranslator): + name = "auto_increase" + n = 0 + + def do_translate(self, text): + self.n += 1 + return str(self.n) + + +class TestTranslator(unittest.TestCase): def setUp(self): self.test_db = cache.init_test_db() def tearDown(self): - # Clean up cache.clean_test_db(self.test_db) - def test_basic_set_get(self): - """Test basic set and get operations""" - cache_instance = cache.TranslationCache("test_engine") - - # Test get with non-existent entry - result = cache_instance.get("hello") - self.assertIsNone(result) - - # Test set and get - cache_instance.set("hello", "你好") - result = cache_instance.get("hello") - self.assertEqual(result, "你好") - - def test_cache_overwrite(self): - """Test that cache entries can be overwritten""" - cache_instance = cache.TranslationCache("test_engine") - - # Set initial translation - cache_instance.set("hello", "你好") - - # Overwrite with new translation - cache_instance.set("hello", "您好") - - # Verify the new translation is returned - result = cache_instance.get("hello") - self.assertEqual(result, "您好") - - def test_non_string_params(self): - """Test that non-string parameters are automatically converted to JSON""" - params = {"model": "gpt-3.5", "temperature": 0.7} - cache_instance = cache.TranslationCache("test_engine", params) - - # Test that params are converted to JSON string internally - cache_instance.set("hello", "你好") - result = cache_instance.get("hello") - self.assertEqual(result, "你好") - - # Test with different param types - array_params = ["param1", "param2"] - cache_instance2 = cache.TranslationCache("test_engine", array_params) - cache_instance2.set("hello", "你好2") - self.assertEqual(cache_instance2.get("hello"), "你好2") - - # Test with nested structures - nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}} - cache_instance3 = cache.TranslationCache("test_engine", nested_params) - cache_instance3.set("hello", "你好3") - self.assertEqual(cache_instance3.get("hello"), "你好3") - - def test_engine_distinction(self): - """Test that cache distinguishes between different translation engines""" - cache1 = cache.TranslationCache("engine1") - cache2 = cache.TranslationCache("engine2") - - # Set same text with different engines - cache1.set("hello", "你好 1") - cache2.set("hello", "你好 2") - - # Verify each engine gets its own translation - self.assertEqual(cache1.get("hello"), "你好 1") - self.assertEqual(cache2.get("hello"), "你好 2") - - def test_params_distinction(self): - """Test that cache distinguishes between different engine parameters""" - params1 = {"param": "value1"} - params2 = {"param": "value2"} - cache1 = cache.TranslationCache("test_engine", params1) - cache2 = cache.TranslationCache("test_engine", params2) - - # Set same text with different parameters - cache1.set("hello", "你好 1") - cache2.set("hello", "你好 2") - - # Verify each parameter set gets its own translation - self.assertEqual(cache1.get("hello"), "你好 1") - self.assertEqual(cache2.get("hello"), "你好 2") - - def test_consistent_param_serialization(self): - """Test that dictionary parameters are consistently serialized regardless of key order""" - # Test simple dictionary - params1 = {"b": 1, "a": 2} - params2 = {"a": 2, "b": 1} - cache1 = cache.TranslationCache("test_engine", params1) - cache2 = cache.TranslationCache("test_engine", params2) - self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) - - # Test nested dictionary - params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3} - params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}} - cache1 = cache.TranslationCache("test_engine", params1) - cache2 = cache.TranslationCache("test_engine", params2) - self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) - - # Test dictionary with list of dictionaries - params1 = {"b": [{"y": 1, "x": 2}], "a": 3} - params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} - cache1 = cache.TranslationCache("test_engine", params1) - cache2 = cache.TranslationCache("test_engine", params2) - self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) - - # Test that different values still produce different results - params1 = {"a": 1, "b": 2} - params2 = {"a": 2, "b": 1} - cache1 = cache.TranslationCache("test_engine", params1) - cache2 = cache.TranslationCache("test_engine", params2) - self.assertNotEqual( - cache1.translate_engine_params, cache2.translate_engine_params - ) - - def test_cache_with_sorted_params(self): - """Test that cache works correctly with sorted parameters""" - params1 = {"b": [{"y": 1, "x": 2}], "a": 3} - params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} - - # Both caches should work with the same key - cache1 = cache.TranslationCache("test_engine", params1) - cache1.set("hello", "你好") - - cache2 = cache.TranslationCache("test_engine", params2) - self.assertEqual(cache2.get("hello"), "你好") - - def test_append_params(self): - """Test the append_params method""" - cache_instance = cache.TranslationCache("test_engine", {"initial": "value"}) - - # Test appending new parameter - cache_instance.add_params("new_param", "new_value") - self.assertEqual( - cache_instance.params, {"initial": "value", "new_param": "new_value"} - ) - - # Test that cache with appended params works correctly - cache_instance.set("hello", "你好") - self.assertEqual(cache_instance.get("hello"), "你好") - - # Test overwriting existing parameter - cache_instance.add_params("initial", "new_value") - self.assertEqual( - cache_instance.params, {"initial": "new_value", "new_param": "new_value"} - ) - - # Cache should work with updated params - cache_instance.set("hello2", "你好2") - self.assertEqual(cache_instance.get("hello2"), "你好2") - - def test_thread_safety(self): - """Test thread safety of cache operations""" - cache_instance = cache.TranslationCache("test_engine") - lock = threading.Lock() - results = [] - num_threads = multiprocessing.cpu_count() - items_per_thread = 100 - - def generate_random_text(length=10): - return "".join( - random.choices(string.ascii_letters + string.digits, k=length) - ) - - def worker(): - thread_results = [] # 线程本地存储结果 - for _ in range(items_per_thread): - text = generate_random_text() - translation = f"翻译_{text}" - - # Write operation - cache_instance.set(text, translation) - - # Read operation - verify our own write - result = cache_instance.get(text) - thread_results.append((text, result)) - - # 所有操作完成后,一次性加锁并追加结果 - with lock: - results.extend(thread_results) - - # Create threads equal to CPU core count - threads = [] - for _ in range(num_threads): - thread = threading.Thread(target=worker) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify all operations were successful - expected_total = num_threads * items_per_thread - self.assertEqual(len(results), expected_total) - - # Verify each thread got its correct value - for text, result in results: - expected = f"翻译_{text}" - self.assertEqual(result, expected) + def test_cache(self): + translator = AutoIncreaseTranslator("en", "zh", "test") + # First translation should be cached + text = "Hello World" + first_result = translator.translate(text) + + # Second translation should return the same result from cache + second_result = translator.translate(text) + self.assertEqual(first_result, second_result) + + # Different input should give different result + different_text = "Different Text" + different_result = translator.translate(different_text) + self.assertNotEqual(first_result, different_result) + + # Test cache with ignore_cache=True + translator.ignore_cache = True + no_cache_result = translator.translate(text) + self.assertNotEqual(first_result, no_cache_result) + + def test_add_cache_impact_parameters(self): + translator = AutoIncreaseTranslator("en", "zh", "test") + + # Test cache with added parameters + text = "Hello World" + first_result = translator.translate(text) + translator.add_cache_impact_parameters("test", "value") + second_result = translator.translate(text) + self.assertNotEqual(first_result, second_result) + + # Test cache with ignore_cache=True + translator.ignore_cache = True + no_cache_result = translator.translate(text) + self.assertNotEqual(first_result, no_cache_result) + + # Test cache with ignore_cache=False + translator.ignore_cache = False + cache_result = translator.translate(text) + self.assertEqual(second_result, cache_result) + + # Test cache with another parameter + translator.add_cache_impact_parameters("test2", "value2") + another_result = translator.translate(text) + self.assertNotEqual(second_result, another_result) if __name__ == "__main__":