From d8db405a4c5597314aceba7252715c2b11c8c5cf Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 6 Nov 2019 00:39:05 +0100 Subject: [PATCH 1/5] warn if passing raw images to single-channel models --- ocrd_calamari/recognize.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/ocrd_calamari/recognize.py b/ocrd_calamari/recognize.py index 92aa5a4..31a37e1 100644 --- a/ocrd_calamari/recognize.py +++ b/ocrd_calamari/recognize.py @@ -31,6 +31,14 @@ def _init_calamari(self): checkpoints = glob(self.parameter['checkpoint']) self.predictor = MultiPredictor(checkpoints=checkpoints) + self.input_channels = self.predictor.predictors[0].network.input_channels + #self.input_channels = self.predictor.predictors[0].network_params.channels # not used! + # binarization = self.predictor.predictors[0].model_params.data_preprocessor.binarization # not used! + # self.features = ('' if self.input_channels != 1 else + # 'binarized' if binarization != 'GRAY' else + # 'grayscale_normalized') + self.features = '' + voter_params = VoterParams() voter_params.type = VoterParams.Type.Value(self.parameter['voter'].upper()) self.voter = voter_from_proto(voter_params) @@ -54,17 +62,30 @@ def process(self): pcgts = page_from_file(self.workspace.download_file(input_file)) page = pcgts.get_Page() - page_image, page_xywh, page_image_info = self.workspace.image_from_page(page, page_id) + page_image, page_coords, page_image_info = self.workspace.image_from_page( + page, page_id, feature_selector=self.features) - for region in pcgts.get_Page().get_TextRegion(): - region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh) + for region in page.get_TextRegion(): + region_image, region_coords = self.workspace.image_from_segment( + region, page_image, page_coords, feature_selector=self.features) textlines = region.get_TextLine() log.info("About to recognize %i lines of region '%s'", len(textlines), region.id) - for (line_no, line) in enumerate(textlines): - log.debug("Recognizing line '%s' in region '%s'", line_no, region.id) - - line_image, line_xywh = self.workspace.image_from_segment(line, region_image, region_xywh) + for line in textlines: + log.debug("Recognizing line '%s' in region '%s'", line.id, region.id) + + line_image, line_coords = self.workspace.image_from_segment( + line, region_image, region_coords, feature_selector=self.features) + if ('binarized' not in line_coords['features'] and + 'grayscale_normalized' not in line_coords['features'] and + self.input_channels == 1): + # We cannot use a feature selector for this since we don't + # know whether the model expects (has been trained on) + # binarized or grayscale images; but raw images are likely + # always inadequate: + log.warning("Using raw image for line '%s' in region '%s'", + line.id, region.id) + line_image_np = np.array(line_image, dtype=np.uint8) raw_results = list(self.predictor.predict_raw([line_image_np], progress_bar=False))[0] From f20eb3ba459f4d80eeab743c11fabab5b2241750 Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Thu, 5 Dec 2019 14:58:24 +0100 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=8E=A8=20Refactor=20model=20path=20co?= =?UTF-8?q?nstant=20into=20a=20variable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_recognize.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_recognize.py b/test/test_recognize.py index 0fca48f..19932f4 100644 --- a/test/test_recognize.py +++ b/test/test_recognize.py @@ -10,6 +10,7 @@ from .base import assets METS_KANT = assets.url_of('kant_aufklaerung_1784-page-block-line-word_glyph/data/mets.xml') +CHECKPOINT = os.path.join(os.getcwd(), 'gt4histocr-calamari/*.ckpt.json') WORKSPACE_DIR = '/tmp/test-ocrd-calamari' @@ -52,9 +53,7 @@ def test_recognize(workspace): workspace, input_file_grp="OCR-D-GT-SEG-LINE", output_file_grp="OCR-D-OCR-CALAMARI", - parameter={ - 'checkpoint': os.path.join(os.getcwd(), 'gt4histocr-calamari/*.ckpt.json') - } + parameter={'checkpoint': CHECKPOINT} ).process() workspace.save_mets() From 377466a71ad76a73f8f1cd246a3484a478fc2296 Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Thu, 5 Dec 2019 14:59:14 +0100 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9C=85=20Add=20test=20to=20check=20if=20?= =?UTF-8?q?we=20warn=20when=20processing=20a=20"raw"/RGB=20image=20with=20?= =?UTF-8?q?a=20single-channel=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_recognize.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_recognize.py b/test/test_recognize.py index 19932f4..f97ef91 100644 --- a/test/test_recognize.py +++ b/test/test_recognize.py @@ -4,6 +4,7 @@ import urllib.request import pytest +import logging from ocrd.resolver import Resolver from ocrd_calamari import CalamariRecognize @@ -61,3 +62,16 @@ def test_recognize(workspace): assert os.path.exists(page1) with open(page1, 'r', encoding='utf-8') as f: assert 'verſchuldeten' in f.read() + + +def test_recognize_should_warn_if_given_rgb_image_and_single_channel_model(workspace, caplog): + caplog.set_level(logging.WARNING) + CalamariRecognize( + workspace, + input_file_grp="OCR-D-GT-SEG-LINE", + output_file_grp="OCR-D-OCR-CALAMARI-BROKEN", + parameter={'checkpoint': CHECKPOINT} + ).process() + + interesting_log_messages = [t[2] for t in caplog.record_tuples if "Using raw image" in t[2]] + assert len(interesting_log_messages) > 10 # For every line! From 4cf25b81195ca6e81e9a73f9d6123daafdccfd3e Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Tue, 9 Feb 2021 18:20:46 +0100 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=8E=A8=20Rename=20input=5Fchannels=20?= =?UTF-8?q?variable=20to=20network=5Finput=5Fchannels?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ocrd_calamari/recognize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ocrd_calamari/recognize.py b/ocrd_calamari/recognize.py index 5c6807e..817d6d5 100644 --- a/ocrd_calamari/recognize.py +++ b/ocrd_calamari/recognize.py @@ -48,10 +48,10 @@ def _init_calamari(self): checkpoints = glob(self.parameter['checkpoint']) self.predictor = MultiPredictor(checkpoints=checkpoints) - self.input_channels = self.predictor.predictors[0].network.input_channels - #self.input_channels = self.predictor.predictors[0].network_params.channels # not used! + self.network_input_channels = self.predictor.predictors[0].network.input_channels + #self.network_input_channels = self.predictor.predictors[0].network_params.channels # not used! # binarization = self.predictor.predictors[0].model_params.data_preprocessor.binarization # not used! - # self.features = ('' if self.input_channels != 1 else + # self.features = ('' if self.network_input_channels != 1 else # 'binarized' if binarization != 'GRAY' else # 'grayscale_normalized') self.features = '' @@ -91,7 +91,7 @@ def process(self): log.debug("Recognizing line '%s' in region '%s'", line.id, region.id) line_image, line_coords = self.workspace.image_from_segment(line, region_image, region_coords, feature_selector=self.features) - if ('binarized' not in line_coords['features'] and 'grayscale_normalized' not in line_coords['features'] and self.input_channels == 1): + if ('binarized' not in line_coords['features'] and 'grayscale_normalized' not in line_coords['features'] and self.network_input_channels == 1): # We cannot use a feature selector for this since we don't # know whether the model expects (has been trained on) # binarized or grayscale images; but raw images are likely From f4c0fe857036d1426b7a89a42d491e1749ac6f1d Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Tue, 9 Feb 2021 18:29:49 +0100 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=90=9B=20Fix=20small=20merge=20merge?= =?UTF-8?q?=20error=20(text=20not=20checked=20in=20test=5Frecognize=5Fshou?= =?UTF-8?q?ld=5Fwarn=5Fif=5Fgiven=5Frgb=5Fimage=5Fand=5Fsingle=5Fchannel?= =?UTF-8?q?=5Fmodel)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_recognize.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_recognize.py b/test/test_recognize.py index b3e8540..fdb2679 100644 --- a/test/test_recognize.py +++ b/test/test_recognize.py @@ -95,8 +95,8 @@ def test_recognize_with_checkpoint_dir(workspace): page1 = os.path.join(workspace.directory, "OCR-D-OCR-CALAMARI/OCR-D-OCR-CALAMARI_0001.xml") assert os.path.exists(page1) - with open(page1, 'r', encoding='utf-8') as f: - assert 'verſchuldeten' in f.read() + with open(page1, "r", encoding="utf-8") as f: + assert "verſchuldeten" in f.read() def test_recognize_should_warn_if_given_rgb_image_and_single_channel_model(workspace, caplog): @@ -110,8 +110,6 @@ def test_recognize_should_warn_if_given_rgb_image_and_single_channel_model(works interesting_log_messages = [t[2] for t in caplog.record_tuples if "Using raw image" in t[2]] assert len(interesting_log_messages) > 10 # For every line! - with open(page1, "r", encoding="utf-8") as f: - assert "verſchuldeten" in f.read() def test_word_segmentation(workspace):