Skip to content

Commit

Permalink
fix in collate_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Jan 6, 2024
1 parent 7f2600f commit 1b70baa
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions doctr/datasets/datasets/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,18 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:

@staticmethod
def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
images, targets = zip(*samples)
# FIXME
# problems with some shape != 3
images, targets = [], []
for sample in samples:
if sample[0].shape[-1] == 3:
images.append(sample[0])
targets.append(sample[1])

# images, targets = zip(*samples)
images = tf.stack(images, axis=0)

return images, list(targets)
return images, targets


class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
Expand Down

0 comments on commit 1b70baa

Please sign in to comment.