Skip to content

Commit

Permalink
Add wildreceipt dataset (#1359)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamzaGbada authored Oct 27, 2023
1 parent 3fcf835 commit 7222fe8
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Supported datasets
* IMGUR5K from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.
* MJSynth from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
* IIITHWS from `"Generating Synthetic Data for Text Recognition" <https://github.com/kris314/hwnet>`_.
* WILDRECEIPT from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction" <https://arxiv.org/pdf/2103.14470v1.pdf>`_.


.. toctree::
Expand Down
4 changes: 4 additions & 0 deletions docs/source/using_doctr/using_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ This datasets contains the information to train or validate a text detection mod
+-----------------------------+---------------------------------+---------------------------------+----------------------------------+
| IMGUR5K | 7149 | 796 | Handwritten / external resources |
+-----------------------------+---------------------------------+---------------------------------+----------------------------------+
| WILDRECEIPT | 1268 | 472 | external resources |
+-----------------------------+---------------------------------+---------------------------------+----------------------------------+

.. code:: python3
Expand Down Expand Up @@ -84,6 +86,8 @@ This datasets contains the information to train or validate a text recognition m
+-----------------------------+---------------------------------+---------------------------------+---------------------------------------------+
| IIITHWS | 7141797 | 793533 | english / handwritten / external resources |
+-----------------------------+---------------------------------+---------------------------------+---------------------------------------------+
| WILDRECEIPT | 49377 | 19598 | english / external resources |
+-----------------------------+---------------------------------+---------------------------------+---------------------------------------------+

.. code:: python3
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .synthtext import *
from .utils import *
from .vocabs import *
from .wildreceipt import *

if is_tf_available():
from .loader import *
106 changes: 106 additions & 0 deletions doctr/datasets/wildreceipt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (C) 2021-2023, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import numpy as np

from .datasets import AbstractDataset
from .utils import convert_target_to_relative, crop_bboxes_from_image

__all__ = ["WILDRECEIPT"]


class WILDRECEIPT(AbstractDataset):
    """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction"
    <https://arxiv.org/abs/2103.14470v1>`_ |
    `repository <https://download.openmmlab.com/mmocr/data/wildreceipt.tar>`_.

    The annotation file is expected to hold one JSON document per line, each with a
    ``file_name`` (relative to ``img_folder``) and a list of ``annotations`` carrying an
    8-value ``box`` (x, y of the four corners) and its ``text``.

    >>> # NOTE: You need to download the dataset first.
    >>> from doctr.datasets import WILDRECEIPT
    >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/",
    >>>                         label_path="/path/to/wildreceipt/train.txt")
    >>> img, target = train_set[0]
    >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/",
    >>>                        label_path="/path/to/wildreceipt/test.txt")
    >>> img, target = test_set[0]

    Args:
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations file of the dataset
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
        FileNotFoundError: if `label_path` or `img_folder` does not exist
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )
        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        tmp_root = img_folder
        self.train = train
        np_dtype = np.float32
        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []

        # One JSON document per line; blank lines are ignored so a trailing or stray
        # newline does not crash the JSON parser.
        with open(label_path, "r", encoding="utf-8") as file:
            json_strings = [line for line in file if line.strip()]

        box: Union[List[float], np.ndarray]
        for json_string in json_strings:
            json_data = json.loads(json_string)
            img_path = json_data["file_name"]

            # Targets must be rebuilt for every image: sharing the list across images would
            # attach the boxes and labels of all previous receipts to the current one.
            _targets = []
            for annotation in json_data["annotations"]:
                coordinates = annotation["box"]
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box = np.array(
                        [
                            [coordinates[0], coordinates[1]],
                            [coordinates[2], coordinates[3]],
                            [coordinates[4], coordinates[5]],
                            [coordinates[6], coordinates[7]],
                        ],
                        dtype=np_dtype,
                    )
                else:
                    # Straight box enclosing the four corners: xmin, ymin, xmax, ymax
                    x, y = coordinates[::2], coordinates[1::2]
                    box = [min(x), min(y), max(x), max(y)]
                _targets.append((annotation["text"], box))

            # Nothing to unzip (and nothing to learn from) when an image has no annotation
            if not _targets:
                continue
            text_targets, box_targets = zip(*_targets)
            # Integer pixel coordinates, with negative values clamped to the image origin
            geoms = np.asarray(box_targets, dtype=int).clip(min=0)

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=geoms)
                for crop, label in zip(crops, text_targets):
                    # Keep only non-empty labels made of a single word (no space)
                    if label and " " not in label:
                        self.data.append((crop, label))
            else:
                self.data.append((img_path, dict(boxes=geoms, labels=list(text_targets))))
        self.root = tmp_root

    def extra_repr(self) -> str:
        """Return the subset flag shown in the dataset's repr."""
        return f"train={self.train}"
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,45 @@ def mock_iiithws_dataset(tmpdir_factory, mock_image_stream):
with open(fn, "wb") as f:
f.write(file.getbuffer())
return str(root), str(label_file)


@pytest.fixture(scope="session")
def mock_wildreceipt_dataset(tmpdir_factory, mock_image_stream):
    """Build a tiny on-disk WildReceipt layout: two receipt images and their annotation file.

    Returns:
        the image folder and the annotation file path, both as strings
    """
    dataset_dir = tmpdir_factory.mktemp("datasets").mkdir("wildreceipt")
    image_folder = dataset_dir.mkdir("image_files")

    # One record per receipt: two annotated boxes on the first, one on the second
    records = [
        {
            "file_name": "Image_58/20/receipt_0.jpeg",
            "height": 348,
            "width": 348,
            "annotations": [
                {"box": [263.0, 283.0, 325.0, 283.0, 325.0, 260.0, 263.0, 260.0], "text": "$55.96", "label": 17},
                {"box": [274.0, 308.0, 326.0, 308.0, 326.0, 286.0, 274.0, 286.0], "text": "$4.48", "label": 19},
            ],
        },
        {
            "file_name": "Image_58/20/receipt_1.jpeg",
            "height": 348,
            "width": 348,
            "annotations": [
                {"box": [386.0, 409.0, 599.0, 409.0, 599.0, 373.0, 386.0, 373.0], "text": "089-46169340", "label": 5}
            ],
        },
    ]

    # The annotation file sits next to the image folder, one JSON document per line
    annotation_file = dataset_dir.join("train.txt")
    with open(annotation_file, "w") as f:
        for record in records:
            json.dump(record, f)
            f.write("\n")

    # Every receipt reuses the same mock image bytes
    stream = BytesIO(mock_image_stream)
    receipts_dir = image_folder.mkdir("Image_58").mkdir("20")
    for idx in range(2):
        with open(receipts_dir.join(f"receipt_{idx}.jpeg"), "wb") as f:
            f.write(stream.getbuffer())
    return str(image_folder), str(annotation_file)
26 changes: 26 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False):
# Fetch one sample
img, target = ds[0]

assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
assert img.dtype == torch.float32
Expand Down Expand Up @@ -49,6 +50,7 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly
def _validate_dataset_recognition_part(ds, input_size, batch_size=2):
# Fetch one sample
img, label = ds[0]

assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
assert img.dtype == torch.float32
Expand Down Expand Up @@ -551,6 +553,30 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize("rotate", [True, False])
@pytest.mark.parametrize(
    "input_size, num_samples, recognition",
    [
        [[512, 512], 2, False],
        # NOTE(review): the mock fixture has 2 annotations on the first receipt and 1 on the
        # second, so 5 (= 2 + 3) only holds while the loader carries the first receipt's
        # annotations over to the second one; with per-image targets this would be 3.
        [[32, 128], 5, True],
    ],
)
def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset):
    """Check size, repr and sample format of WILDRECEIPT in detection and recognition modes."""
    img_folder, label_path = mock_wildreceipt_dataset
    ds = datasets.WILDRECEIPT(
        img_folder,
        label_path,
        train=True,
        img_transforms=Resize(input_size),
        use_polygons=rotate,
        recognition_task=recognition,
    )
    assert len(ds) == num_samples
    assert repr(ds) == f"WILDRECEIPT(train={True})"
    if not recognition:
        _validate_dataset(ds, input_size, is_polygons=rotate)
    else:
        _validate_dataset_recognition_part(ds, input_size)


# NOTE: following datasets are only for recognition task


Expand Down
24 changes: 24 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,30 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset):
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize("rotate", [True, False])
@pytest.mark.parametrize(
    "input_size, num_samples, recognition",
    [
        [[512, 512], 2, False],
        # NOTE(review): the mock fixture has 2 annotations on the first receipt and 1 on the
        # second, so 5 (= 2 + 3) only holds while the loader carries the first receipt's
        # annotations over to the second one; with per-image targets this would be 3.
        [[32, 128], 5, True],
    ],
)
def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset):
    """Check size, repr and sample format of WILDRECEIPT in detection and recognition modes."""
    img_folder, label_path = mock_wildreceipt_dataset
    ds = datasets.WILDRECEIPT(
        img_folder,
        label_path,
        train=True,
        img_transforms=Resize(input_size),
        use_polygons=rotate,
        recognition_task=recognition,
    )
    assert len(ds) == num_samples
    assert repr(ds) == f"WILDRECEIPT(train={True})"
    if not recognition:
        _validate_dataset(ds, input_size, is_polygons=rotate)
    else:
        _validate_dataset_recognition_part(ds, input_size)


# NOTE: following datasets are only for recognition task


Expand Down

0 comments on commit 7222fe8

Please sign in to comment.