-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdeepearthgo_image.py
795 lines (637 loc) · 24.7 KB
/
deepearthgo_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
# pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements
"""Read individual image files and perform augmentations."""
from __future__ import absolute_import, print_function
import os
import random
import logging
import numpy as np
try:
import cv2
except ImportError:
cv2 = None
from .base import numeric_types
from . import ndarray as nd
from . import _ndarray_internal as _internal
from ._ndarray_internal import _cvimresize as imresize
from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
from . import io
from . import recordio
def imdecode(buf, **kwargs):
"""Decode an image to an NDArray.
Note: `imdecode` uses OpenCV (not the CV2 Python library).
MXNet must have been built with OpenCV for `imdecode` to work.
Parameters
----------
buf : str/bytes or numpy.ndarray
Binary image data as string or numpy ndarray.
flag : int, optional, default=1
1 for three channel color output. 0 for grayscale output.
to_rgb : int, optional, default=1
1 for RGB formatted output (MXNet default). 0 for BGR formatted output (OpenCV default).
out : NDArray, optional
Output buffer. Use `None` for automatic allocation.
Returns
-------
NDArray
An `NDArray` containing the image.
Example
-------
>>> with open("flower.jpg", 'rb') as fp:
... str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image)
>>> image
<NDArray 224x224x3 @cpu(0)>
Set `flag` parameter to 0 to get grayscale output
>>> with open("flower.jpg", 'rb') as fp:
... str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image, flag=0)
>>> image
<NDArray 224x224x1 @cpu(0)>
Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)
>>> with open("flower.jpg", 'rb') as fp:
... str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image, to_rgb=0)
>>> image
<NDArray 224x224x3 @cpu(0)>
"""
if not isinstance(buf, nd.NDArray):
buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
return _internal._cvimdecode(buf, **kwargs)
def scale_down(src_size, size):
"""Scales down crop size if it's larger than image size.
If width/height of the crop is larger than the width/height of the image,
sets the width/height to the width/height of the image.
Parameters
----------
src_size : tuple of int
Size of the image in (width, height) format.
size : tuple of int
Size of the crop in (width, height) format.
Returns
-------
tuple of int
A tuple containing the scaled crop size in (width, height) format.
Example
--------
>>> src_size = (640,480)
>>> size = (720,120)
>>> new_size = mx.img.scale_down(src_size, size)
>>> new_size
(640,106)
"""
w, h = size
sw, sh = src_size
if sh < h:
w, h = float(w * sh) / h, sh
if sw < w:
w, h = sw, float(h * sw) / w
return int(w), int(h)
def resize_short(src, size, interp=2):
"""Resizes shorter edge to size.
Note: `resize_short` uses OpenCV (not the CV2 Python library).
MXNet must have been built with OpenCV for `resize_short` to work.
Resizes the original image by setting the shorter edge to size
and setting the longer edge accordingly.
Resizing function is called from OpenCV.
Parameters
----------
src : NDArray
The original image.
size : int
The length to be set for the shorter edge.
interp : int, optional, default=2
Interpolation method used for resizing the image.
Default method is bicubic interpolation.
More details can be found in the documentation of OpenCV, please refer to
http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
Returns
-------
NDArray
An 'NDArray' containing the resized image.
Example
-------
>>> with open("flower.jpeg", 'rb') as fp:
... str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image)
>>> image
<NDArray 2321x3482x3 @cpu(0)>
>>> size = 640
>>> new_image = mx.img.resize_short(image, size)
>>> new_image
<NDArray 2321x3482x3 @cpu(0)>
"""
h, w, _ = src.shape
if h > w:
new_h, new_w = size * h / w, size
else:
new_h, new_w = size, size * w / h
return imresize(src, new_w, new_h, interp=interp)
def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
"""Crop src at fixed location, and (optionally) resize it to size."""
out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
if size is not None and (w, h) != size:
out = imresize(out, *size, interp=interp)
return out
def random_crop(src, size, interp=2):
"""Randomly crop `src` with `size` (width, height).
Upsample result if `src` is smaller than `size`.
Parameters
----------
src: Source image `NDArray`
size: Size of the crop formatted as (width, height). If the `size` is larger
than the image, then the source image is upsampled to `size` and returned.
interp: Interpolation method to be used in case the size is larger (default: bicubic).
Uses OpenCV convention for the parameters. Nearest - 0, Bilinear - 1, Bicubic - 2,
Area - 3. See OpenCV imresize function for more details.
Returns
-------
NDArray
An `NDArray` containing the cropped image.
Tuple
A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
original image and (width, height) are the dimensions of the cropped image.
Example
-------
>>> im = mx.nd.array(cv2.imread("flower.jpg"))
>>> cropped_im, rect = mx.image.random_crop(im, (100, 100))
>>> print cropped_im
<NDArray 100x100x1 @cpu(0)>
>>> print rect
(20, 21, 100, 100)
"""
h, w, _ = src.shape
new_w, new_h = scale_down((w, h), size)
x0 = random.randint(0, w - new_w)
y0 = random.randint(0, h - new_h)
out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
return out, (x0, y0, new_w, new_h)
def center_crop(src, size, interp=2):
"""Crops the image `src` to the given `size` by trimming on all four
sides and preserving the center of the image. Upsamples if `src` is smaller
than `size`.
.. note:: This requires MXNet to be compiled with USE_OPENCV.
Parameters
----------
src : NDArray
Binary source image data.
size : list or tuple of int
The desired output image size.
interp : interpolation, optional, default=Area-based
The type of interpolation that is done to the image.
Possible values:
0: Nearest Neighbors Interpolation.
1: Bilinear interpolation.
2: Area-based (resampling using pixel area relation). It may be a
preferred method for image decimation, as it gives moire-free
results. But when the image is zoomed, it is similar to the Nearest
Neighbors method. (used by default).
3: Bicubic interpolation over 4x4 pixel neighborhood.
4: Lanczos interpolation over 8x8 pixel neighborhood.
When shrinking an image, it will generally look best with AREA-based
interpolation, whereas, when enlarging an image, it will generally look best
with Bicubic (slow) or Bilinear (faster but still looks OK).
Returns
-------
NDArray
The cropped image.
Tuple
(x, y, width, height) where x, y are the positions of the crop in the
original image and width, height the dimensions of the crop.
Example
-------
>>> with open("flower.jpg", 'rb') as fp:
... str_image = fp.read()
...
>>> image = mx.image.imdecode(str_image)
>>> image
<NDArray 2321x3482x3 @cpu(0)>
>>> cropped_image, (x, y, width, height) = mx.image.center_crop(image, (1000, 500))
>>> cropped_image
<NDArray 500x1000x3 @cpu(0)>
>>> x, y, width, height
(1241, 910, 1000, 500)
"""
h, w, _ = src.shape
new_w, new_h = scale_down((w, h), size)
x0 = int((w - new_w) / 2)
y0 = int((h - new_h) / 2)
out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
return out, (x0, y0, new_w, new_h)
def color_normalize(src, mean, std=None):
"""Normalize src with mean and std."""
src -= mean
if std is not None:
src /= std
return src
def random_size_crop(src, size, min_area, ratio, interp=2):
"""Randomly crop src with size. Randomize area and aspect ratio."""
h, w, _ = src.shape
new_ratio = random.uniform(*ratio)
if new_ratio * h > w:
max_area = w * int(w / new_ratio)
else:
max_area = h * int(h * new_ratio)
min_area *= h * w
if max_area < min_area:
return random_crop(src, size, interp)
new_area = random.uniform(min_area, max_area)
new_w = int(np.sqrt(new_area * new_ratio))
new_h = int(np.sqrt(new_area / new_ratio))
assert new_w <= w and new_h <= h
x0 = random.randint(0, w - new_w)
y0 = random.randint(0, h - new_h)
out = fixed_crop(src, x0, y0, new_w, new_h, size, interp)
return out, (x0, y0, new_w, new_h)
def ResizeAug(size, interp=2):
"""Make resize shorter edge to size augmenter."""
def aug(src):
"""Augmenter body"""
return [resize_short(src, size, interp)]
return aug
def RandomCropAug(size, interp=2):
"""Make random crop augmenter"""
def aug(src):
"""Augmenter body"""
return [random_crop(src, size, interp)[0]]
return aug
def RandomSizedCropAug(size, min_area, ratio, interp=2):
"""Make random crop with random resizing and random aspect ratio jitter augmenter."""
def aug(src):
"""Augmenter body"""
return [random_size_crop(src, size, min_area, ratio, interp)[0]]
return aug
def CenterCropAug(size, interp=2):
"""Make center crop augmenter."""
def aug(src):
"""Augmenter body"""
return [center_crop(src, size, interp)[0]]
return aug
def RandomOrderAug(ts):
"""Apply list of augmenters in random order"""
def aug(src):
"""Augmenter body"""
src = [src]
random.shuffle(ts)
for t in ts:
src = [j for i in src for j in t(i)]
return src
return aug
def ColorJitterAug(brightness, contrast, saturation):
"""Apply random brightness, contrast and saturation jitter in random order."""
ts = []
coef = nd.array([[[0.299, 0.587, 0.114]]])
if brightness > 0:
def baug(src):
"""Augmenter body"""
alpha = 1.0 + random.uniform(-brightness, brightness)
src *= alpha
return [src]
ts.append(baug)
if contrast > 0:
def caug(src):
"""Augmenter body"""
alpha = 1.0 + random.uniform(-contrast, contrast)
gray = src * coef
gray = (3.0 * (1.0 - alpha) / gray.size) * nd.sum(gray)
src *= alpha
src += gray
return [src]
ts.append(caug)
if saturation > 0:
def saug(src):
"""Augmenter body"""
alpha = 1.0 + random.uniform(-saturation, saturation)
gray = src * coef
gray = nd.sum(gray, axis=2, keepdims=True)
gray *= (1.0 - alpha)
src *= alpha
src += gray
return [src]
ts.append(saug)
return RandomOrderAug(ts)
def LightingAug(alphastd, eigval, eigvec):
"""Add PCA based noise."""
def aug(src):
"""Augmenter body"""
alpha = np.random.normal(0, alphastd, size=(3,))
rgb = np.dot(eigvec * alpha, eigval)
src += nd.array(rgb)
return [src]
return aug
def ColorNormalizeAug(mean, std):
"""Mean and std normalization."""
mean = nd.array(mean)
std = nd.array(std)
def aug(src):
"""Augmenter body"""
return [color_normalize(src, mean, std)]
return aug
def HorizontalFlipAug(p):
"""Random horizontal flipping."""
def aug(src):
"""Augmenter body"""
if random.random() < p:
src = nd.flip(src, axis=1)
return [src]
return aug
def CastAug():
"""Cast to float32"""
def aug(src):
"""Augmenter body"""
src = src.astype(np.float32)
return [src]
return aug
def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False,
mean=None, std=None, brightness=0, contrast=0, saturation=0,
pca_noise=0, inter_method=2):
"""Creates an augmenter list."""
auglist = []
if resize > 0:
auglist.append(ResizeAug(resize, inter_method))
crop_size = (data_shape[2], data_shape[1])
if rand_resize:
assert rand_crop
auglist.append(RandomSizedCropAug(crop_size, 0.3, (3.0 / 4.0, 4.0 / 3.0), inter_method))
elif rand_crop:
auglist.append(RandomCropAug(crop_size, inter_method))
else:
auglist.append(CenterCropAug(crop_size, inter_method))
if rand_mirror:
auglist.append(HorizontalFlipAug(0.5))
auglist.append(CastAug())
if brightness or contrast or saturation:
auglist.append(ColorJitterAug(brightness, contrast, saturation))
if pca_noise > 0:
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203]])
auglist.append(LightingAug(pca_noise, eigval, eigvec))
if mean is True:
mean = np.array([123.68, 116.28, 103.53])
elif mean is not None:
assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]
if std is True:
std = np.array([58.395, 57.12, 57.375])
elif std is not None:
assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]
if mean is not None and std is not None:
auglist.append(ColorNormalizeAug(mean, std))
return auglist
class ImageIter(io.DataIter):
"""Image data iterator with a large number of augmentation choices.
This iterator supports reading from both .rec files and raw image files.
To load input images from .rec files, use `path_imgrec` parameter and to load from raw image
files, use `path_imglist` and `path_root` parameters.
To use data partition (for distributed training) or shuffling, specify `path_imgidx` parameter.
Parameters
----------
batch_size : int
Number of examples per batch.
data_shape : tuple
Data shape in (channels, height, width) format.
For now, only RGB image with 3 channels is supported.
label_width : int, optional
Number of labels per example. The default label width is 1.
path_imgrec : str
Path to image record file (.rec).
Created with tools/im2rec.py or bin/im2rec.
path_imglist : str
Path to image list (.lst).
Created with tools/im2rec.py or with custom script.
Format: Tab separated record of index, one or more labels and relative_path_from_root.
imglist: list
A list of images with the label(s).
Each item is a list [imagelabel: float or list of float, imgpath].
path_root : str
Root folder of image files.
path_imgidx : str
Path to image index file. Needed for partition and shuffling when using .rec source.
shuffle : bool
Whether to shuffle all images at the start of each iteration or not.
Can be slow for HDD.
part_index : int
Partition index.
num_parts : int
Total number of partitions.
data_name : str
Data name for provided symbols.
label_name : str
Label name for provided symbols.
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateAugmenter.
"""
def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='softmax_label', multi_single_label=0, line_num=0, total_number=0,**kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist or (isinstance(imglist, list))
num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1)
logging.info('Using %s threads for decoding...', str(num_threads))
logging.info('Set enviroment variable MXNET_CPU_WORKER_NTHREADS to a'
' larger number to use more threads.')
class_name = self.__class__.__name__
if path_imgrec:
logging.info('%s: loading recordio %s...',
class_name, path_imgrec)
if path_imgidx:
self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec,
'r') # pylint: disable=redefined-variable-type
self.imgidx = list(self.imgrec.keys)
else:
self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type
self.imgidx = None
else:
self.imgrec = None
# dxh
self.line_num = line_num
# dxh
if path_imglist:
logging.info('%s: loading image list %s...', class_name, path_imglist)
with open(path_imglist) as fin:
imglist = {}
imgkeys = []
# dxh
line_dxh = 0
# dxh
for line in iter(fin.readline, ''):
# dxh
if self.line_num >= 2:
line_dxh += 1
if line_dxh % self.line_num == 0:
line = line.strip().split('\t')
label = nd.array([float(i) for i in line[1:-1]])
key = int(line[0])
imglist[key] = (label, line[-1])
imgkeys.append(key)
if len(imgkeys) == (int(total_number)/int(self.line_num))/int(batch_size) * int(batch_size):
break
# dxh
else:
line = line.strip().split('\t')
label = nd.array([float(i) for i in line[1:-1]])
key = int(line[0])
imglist[key] = (label, line[-1])
imgkeys.append(key)
self.imglist = imglist
elif isinstance(imglist, list):
logging.info('%s: loading image list...', class_name)
result = {}
imgkeys = []
index = 1
for img in imglist:
key = str(index) # pylint: disable=redefined-variable-type
index += 1
if len(img) > 2:
label = nd.array(img[:-1])
elif isinstance(img[0], numeric_types):
label = nd.array([img[0]])
else:
label = nd.array(img[0])
result[key] = (label, img[-1])
imgkeys.append(str(key))
self.imglist = result
else:
self.imglist = None
self.path_root = path_root
self.check_data_shape(data_shape)
self.provide_data = [(data_name, (batch_size,) + data_shape)]
###dxh
self.multi_single_label = multi_single_label
self.label_length = len(label_name)
if multi_single_label == 1:
self.provide_label = []
for dxh_i1 in range(len(label_name)):
if label_width > 1:
self.provide_label.append((label_name[dxh_i1], (batch_size, label_width)))
else:
self.provide_label.append((label_name[dxh_i1], (batch_size,)))
elif multi_single_label == 0:
if label_width > 1:
self.provide_label = [(label_name, (batch_size, label_width))]
else:
self.provide_label = [(label_name, (batch_size,))]
###dxh
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
elif shuffle or num_parts > 1:
assert self.imgidx is not None
self.seq = self.imgidx
else:
self.seq = None
if num_parts > 1:
assert part_index < num_parts
N = len(self.seq)
C = N // num_parts
self.seq = self.seq[part_index * C:(part_index + 1) * C]
if aug_list is None:
self.auglist = CreateAugmenter(data_shape, **kwargs)
else:
self.auglist = aug_list
self.cur = 0
self.reset()
def reset(self):
"""Resets the iterator to the beginning of the data."""
if self.shuffle:
random.shuffle(self.seq)
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
def next_sample(self):
"""Helper function for reading in next sample."""
if self.seq is not None:
if self.cur >= len(self.seq):
raise StopIteration
idx = self.seq[self.cur]
self.cur += 1
if self.imgrec is not None:
s = self.imgrec.read_idx(idx)
header, img = recordio.unpack(s)
if self.imglist is None:
return header.label, img
else:
return self.imglist[idx][0], img
else:
label, fname = self.imglist[idx]
return label, self.read_image(fname)
else:
s = self.imgrec.read()
if s is None:
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img
def next(self):
"""Returns the next batch of data."""
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.empty((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = 0
try:
while i < batch_size:
label, s = self.next_sample()
data = [self.imdecode(s)]
try:
self.check_valid_image(data)
except RuntimeError as e:
logging.debug('Invalid image, skipping: %s', str(e))
continue
data = self.augmentation_transform(data)
for datum in data:
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
batch_data[i][:] = self.postprocess_data(datum)
batch_label[i][:] = label
i += 1
except StopIteration:
if not i:
raise StopIteration
# dxh
batch_lable_dxh = [batch_label]
if self.multi_single_label == 1:
for dxh_i1 in range(self.label_length - 1):
batch_label_dxh.append(batch_label)
else:
batch_label_dxh = [batch_label]
# dxh
return io.DataBatch([batch_data], batch_label_dxh, batch_size - i)
def check_data_shape(self, data_shape):
"""Checks if the input data shape is valid"""
if not len(data_shape) == 3:
raise ValueError('data_shape should have length 3, with dimensions CxHxW')
if not data_shape[0] == 3:
raise ValueError('This iterator expects inputs to have 3 channels.')
def check_valid_image(self, data):
"""Checks if the input data is valid"""
if len(data[0].shape) == 0:
raise RuntimeError('Data shape is wrong')
def imdecode(self, s):
"""Decodes a string or byte string to an NDArray.
See mx.img.imdecode for more details."""
return imdecode(s)
def read_image(self, fname):
"""Reads an input image `fname` and returns the decoded raw bytes.
Example usage:
----------
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
"""
with open(os.path.join(self.path_root, fname), 'rb') as fin:
img = fin.read()
return img
def augmentation_transform(self, data):
"""Transforms input data with specified augmentation."""
for aug in self.auglist:
data = [ret for src in data for ret in aug(src)]
return data
def postprocess_data(self, datum):
"""Final postprocessing step before image is loaded into the batch."""
return nd.transpose(datum, axes=(2, 0, 1))