-
-
Notifications
You must be signed in to change notification settings - Fork 702
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #282 from YadominJinta/main
feat: add unit test
- Loading branch information
Showing
4 changed files
with
246 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import unittest | ||
from unittest.mock import Mock, patch, MagicMock | ||
from pdfminer.layout import LTPage, LTChar, LTLine | ||
from pdfminer.pdfinterp import PDFResourceManager | ||
from pdf2zh.converter import PDFConverterEx, TranslateConverter | ||
|
||
|
||
class TestPDFConverterEx(unittest.TestCase): | ||
def setUp(self): | ||
self.rsrcmgr = PDFResourceManager() | ||
self.converter = PDFConverterEx(self.rsrcmgr) | ||
|
||
def test_begin_page(self): | ||
mock_page = Mock() | ||
mock_page.pageno = 1 | ||
mock_page.cropbox = (0, 0, 100, 200) | ||
mock_ctm = [1, 0, 0, 1, 0, 0] | ||
self.converter.begin_page(mock_page, mock_ctm) | ||
self.assertIsNotNone(self.converter.cur_item) | ||
self.assertEqual(self.converter.cur_item.pageid, 1) | ||
|
||
def test_render_char(self): | ||
mock_matrix = (1, 2, 3, 4, 5, 6) | ||
mock_font = Mock() | ||
mock_font.to_unichr.return_value = "A" | ||
mock_font.char_width.return_value = 10 | ||
mock_font.char_disp.return_value = (0, 0) | ||
graphic_state = Mock() | ||
self.converter.cur_item = Mock() | ||
result = self.converter.render_char( | ||
mock_matrix, | ||
mock_font, | ||
fontsize=12, | ||
scaling=1.0, | ||
rise=0, | ||
cid=65, | ||
ncs=None, | ||
graphicstate=graphic_state, | ||
) | ||
self.assertEqual(result, 120.0) # Expected text width | ||
|
||
|
||
class TestTranslateConverter(unittest.TestCase): | ||
def setUp(self): | ||
self.rsrcmgr = PDFResourceManager() | ||
self.layout = {1: Mock()} | ||
self.translator_class = Mock() | ||
self.converter = TranslateConverter( | ||
self.rsrcmgr, | ||
layout=self.layout, | ||
lang_in="en", | ||
lang_out="zh", | ||
service="google", | ||
) | ||
|
||
def test_translator_initialization(self): | ||
self.assertIsNotNone(self.converter.translator) | ||
self.assertEqual(self.converter.translator.lang_in, "en") | ||
self.assertEqual(self.converter.translator.lang_out, "zh-CN") | ||
|
||
@patch("pdf2zh.converter.TranslateConverter.receive_layout") | ||
def test_receive_layout(self, mock_receive_layout): | ||
mock_page = LTPage(1, (0, 0, 100, 200)) | ||
mock_font = Mock() | ||
mock_font.fontname.return_value = "mock_font" | ||
mock_page.add( | ||
LTChar( | ||
matrix=(1, 2, 3, 4, 5, 6), | ||
font=mock_font, | ||
fontsize=12, | ||
scaling=1.0, | ||
rise=0, | ||
text="A", | ||
textwidth=10, | ||
textdisp=(1.0, 1.0), | ||
ncs=Mock(), | ||
graphicstate=Mock(), | ||
) | ||
) | ||
self.converter.receive_layout(mock_page) | ||
mock_receive_layout.assert_called_once_with(mock_page) | ||
|
||
@patch("concurrent.futures.ThreadPoolExecutor") | ||
@patch("pdf2zh.cache") | ||
def test_translation(self, mock_cache, mock_executor): | ||
mock_executor.return_value.__enter__.return_value.map.return_value = [ | ||
"你好", | ||
"{v1}", | ||
] | ||
mock_cache.deterministic_hash.return_value = "test_hash" | ||
mock_cache.load_paragraph.return_value = None | ||
mock_cache.write_paragraph.return_value = None | ||
|
||
sstk = ["Hello", "{v1}"] | ||
self.converter.thread = 2 | ||
results = [] | ||
with patch.object(self.converter, "translator") as mock_translator: | ||
mock_translator.translate.side_effect = lambda x: ( | ||
"你好" if x == "Hello" else x | ||
) | ||
for s in sstk: | ||
results.append(self.converter.translator.translate(s)) | ||
self.assertEqual(results, ["你好", "{v1}"]) | ||
|
||
def test_receive_layout_with_complex_formula(self): | ||
ltpage = LTPage(1, (0, 0, 500, 500)) | ||
ltchar = Mock() | ||
ltchar.fontname.return_value = "mock_font" | ||
ltline = LTLine(0.1, (0, 0), (10, 20)) | ||
ltpage.add(ltchar) | ||
ltpage.add(ltline) | ||
mock_layout = MagicMock() | ||
mock_layout.shape = (100, 100) | ||
mock_layout.__getitem__.return_value = -1 | ||
self.converter.layout = [None, mock_layout] | ||
self.converter.thread = 1 | ||
result = self.converter.receive_layout(ltpage) | ||
self.assertIsNotNone(result) | ||
|
||
def test_invalid_translation_service(self): | ||
with self.assertRaises(ValueError): | ||
TranslateConverter( | ||
self.rsrcmgr, | ||
layout=self.layout, | ||
lang_in="en", | ||
lang_out="zh", | ||
service="InvalidService", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
import numpy as np | ||
from pdf2zh.doclayout import ( | ||
OnnxModel, | ||
YoloResult, | ||
YoloBox, | ||
) | ||
|
||
|
||
class TestOnnxModel(unittest.TestCase): | ||
@patch("onnx.load") | ||
@patch("onnxruntime.InferenceSession") | ||
def setUp(self, mock_inference_session, mock_onnx_load): | ||
# Mock ONNX model metadata | ||
mock_model = MagicMock() | ||
mock_model.metadata_props = [ | ||
MagicMock(key="stride", value="32"), | ||
MagicMock(key="names", value="['class1', 'class2']"), | ||
] | ||
mock_onnx_load.return_value = mock_model | ||
|
||
# Initialize OnnxModel with a fake path | ||
self.model_path = "fake_model_path.onnx" | ||
self.model = OnnxModel(self.model_path) | ||
|
||
def test_stride_property(self): | ||
# Test that stride is correctly set from model metadata | ||
self.assertEqual(self.model.stride, 32) | ||
|
||
def test_resize_and_pad_image(self): | ||
# Create a dummy image (100x200) | ||
image = np.ones((100, 200, 3), dtype=np.uint8) | ||
resized_image = self.model.resize_and_pad_image(image, 1024) | ||
|
||
# Validate the output shape | ||
self.assertEqual(resized_image.shape[0], 512) | ||
self.assertEqual(resized_image.shape[1], 1024) | ||
|
||
# Check that padding has been added | ||
padded_height = resized_image.shape[0] - image.shape[0] | ||
padded_width = resized_image.shape[1] - image.shape[1] | ||
self.assertGreater(padded_height, 0) | ||
self.assertGreater(padded_width, 0) | ||
|
||
def test_scale_boxes(self): | ||
img1_shape = (1024, 1024) # Model input shape | ||
img0_shape = (500, 300) # Original image shape | ||
boxes = np.array([[512, 512, 768, 768]]) # Example bounding box | ||
|
||
scaled_boxes = self.model.scale_boxes(img1_shape, boxes, img0_shape) | ||
|
||
# Verify the output is scaled correctly | ||
self.assertEqual(scaled_boxes.shape, boxes.shape) | ||
self.assertTrue(np.all(scaled_boxes <= max(img0_shape))) | ||
|
||
def test_predict(self): | ||
# Mock model inference output | ||
mock_output = np.random.random((1, 300, 6)) | ||
self.model.model.run.return_value = [mock_output] | ||
|
||
# Create a dummy image | ||
image = np.ones((500, 300, 3), dtype=np.uint8) | ||
|
||
results = self.model.predict(image) | ||
|
||
# Validate predictions | ||
self.assertEqual(len(results), 1) | ||
self.assertIsInstance(results[0], YoloResult) | ||
self.assertGreater(len(results[0].boxes), 0) | ||
self.assertIsInstance(results[0].boxes[0], YoloBox) | ||
|
||
|
||
class TestYoloResult(unittest.TestCase): | ||
def test_yolo_result(self): | ||
# Example prediction data | ||
boxes = [ | ||
[100, 200, 300, 400, 0.9, 0], | ||
[50, 100, 150, 200, 0.8, 1], | ||
] | ||
names = ["class1", "class2"] | ||
|
||
result = YoloResult(boxes, names) | ||
|
||
# Validate the number of boxes and their order by confidence | ||
self.assertEqual(len(result.boxes), 2) | ||
self.assertGreater(result.boxes[0].conf, result.boxes[1].conf) | ||
self.assertEqual(result.names, names) | ||
|
||
|
||
class TestYoloBox(unittest.TestCase): | ||
def test_yolo_box(self): | ||
# Example box data | ||
box_data = [100, 200, 300, 400, 0.9, 0] | ||
|
||
box = YoloBox(box_data) | ||
|
||
# Validate box properties | ||
self.assertEqual(box.xyxy, box_data[:4]) | ||
self.assertEqual(box.conf, box_data[4]) | ||
self.assertEqual(box.cls, box_data[5]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |