-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathbatch_generator.py
34 lines (30 loc) · 940 Bytes
/
batch_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import random
import pdb
import math
def batch_gen(X, batch_size):
    """Yield consecutive batches of ``X`` taken along its first axis.

    Every batch holds ``batch_size`` rows except possibly the last one,
    which holds whatever rows remain. Batches are plain slices of ``X``
    (no shuffling, no padding), so concatenating them in order gives back
    ``X``. Nothing is yielded when ``X`` has zero rows.

    Args:
        X: Array-like exposing ``.shape`` and slice indexing (the module's
            own demo passes a NumPy array); may have any number of
            dimensions >= 1.
        batch_size: Positive integer number of rows per batch.

    Yields:
        ``X[start:start + batch_size]`` for each successive ``start``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    n_rows = X.shape[0]
    # Stepping the start offset directly handles both the "exact multiple"
    # case and the ragged final batch: slicing past the end simply
    # truncates, so no special-casing of the last batch is needed.
    # Slicing only the first axis keeps all trailing axes intact, so the
    # same expression works for 1-D and N-D inputs alike.
    for start in range(0, n_rows, batch_size):
        yield X[start:start + batch_size]
if __name__ == "__main__":
    # Smoke test: split a random 31223 x 300 matrix into batches of 21 rows
    # and print each batch's shape. 31223 is not a multiple of 21, so the
    # final batch is shorter (17 rows) than the others.
    X = np.random.rand(31223, 300)
    for batch in batch_gen(X, 21):
        # Parenthesised single-argument print gives the same output on
        # Python 2 and is valid syntax on Python 3.
        print(batch.shape)