Skip to content

Commit

Permalink
Add try/except around dataset sample loading so training can proceed on datasets containing broken samples
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Jan 5, 2024
1 parent 7cc10d2 commit 7f2600f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
54 changes: 32 additions & 22 deletions doctr/datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import os
import shutil
import traceback
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -46,28 +47,37 @@ def _read_sample(self, index: int) -> Tuple[Any, Any]:

def __getitem__(self, index: int) -> Tuple[Any, Any]:
# Read image
img, target = self._read_sample(index)
# Pre-transforms (format conversion at run-time etc.)
if self._pre_transforms is not None:
img, target = self._pre_transforms(img, target)

if self.img_transforms is not None:
# typing issue cf. https://github.com/python/mypy/issues/5485
img = self.img_transforms(img)

if self.sample_transforms is not None:
# Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
if (
isinstance(target, dict)
and all(isinstance(item, np.ndarray) for item in target.values())
and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target
):
img_transformed = _copy_tensor(img)
for class_name, bboxes in target.items():
img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
img = img_transformed
else:
img, target = self.sample_transforms(img, target)
try:
img, target = self._read_sample(index)
# Pre-transforms (format conversion at run-time etc.)
if self._pre_transforms is not None:
img, target = self._pre_transforms(img, target)

if self.img_transforms is not None:
# typing issue cf. https://github.com/python/mypy/issues/5485
img = self.img_transforms(img)

if self.sample_transforms is not None:
# Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
if (
isinstance(target, dict)
and all(isinstance(item, np.ndarray) for item in target.values())
and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target
):
img_transformed = _copy_tensor(img)
for class_name, bboxes in target.items():
img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
img = img_transformed
else:
img, target = self.sample_transforms(img, target)
except Exception:
img_name = self.data[index][0]
# Write
print()
print(f"!!!ERROR in Dataset on filename {img_name}")
traceback.print_exc()
print()
return self.__getitem__(0) # should exists ^^

return img, target

Expand Down
16 changes: 11 additions & 5 deletions doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,20 @@ def __init__(

# Build the (filename, (geometries, classes)) pairs, tolerating entries in the
# label file whose image is missing on disk: those are skipped (not loaded)
# and reported once, instead of aborting the whole dataset construction.
self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
np_dtype = np.float32
missing_files = []
for img_name, label in labels.items():
    # File existence check
    if not os.path.exists(os.path.join(self.root, img_name)):
        missing_files.append(img_name)
        continue
    geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
    self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
# Report only when something is actually missing, to keep normal runs quiet
if missing_files:
    print(f"MISSING FILES ({len(missing_files)}):")
    for img_name in missing_files:
        print(f"  - {img_name}")

def format_polygons(
self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
Expand Down
13 changes: 10 additions & 3 deletions doctr/datasets/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,18 @@ def __init__(
with open(labels_path) as f:
    labels = json.load(f)

# Build the (filename, label) pairs, tolerating entries in the label file
# whose image is missing on disk: those are skipped (not loaded) and
# reported once, instead of aborting the whole dataset construction.
missing_files = []
for img_name, label in labels.items():
    if not os.path.exists(os.path.join(self.root, img_name)):
        missing_files.append(img_name)
        continue
    self.data.append((img_name, label))
# Report only when something is actually missing, to keep normal runs quiet
if missing_files:
    print(f"MISSING FILES ({len(missing_files)}):")
    for img_name in missing_files:
        print(f"  - {img_name}")

def merge_dataset(self, ds: AbstractDataset) -> None:
# Update data with new root for self
Expand Down

0 comments on commit 7f2600f

Please sign in to comment.