-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsynthetic_pipeline_reading.py
64 lines (57 loc) · 2.15 KB
/
synthetic_pipeline_reading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
import numpy as np
import tensorflow as tf
def pipeline_definition(shape, sample_count, data_type):
    """Describe a two-stage synthetic pipeline: a random-data source then ``tf.identity``.

    Args:
        shape: Shape of each generated sample. Only ``shape[1]`` is used for
            the stage schemas, so it needs at least two dimensions
            (presumably ``(rows, columns)`` -- TODO confirm with callers).
        sample_count: Number of samples the source dataset yields.
        data_type: ``"float32"`` or ``"uint8"``; selects both the dataset
            builder and the TensorFlow dtype of the schemas.

    Returns:
        A list of two stage dicts, each with keys ``name``, ``type``, ``op``,
        ``input_schema`` and ``output_schema``. The first stage's ``op`` is
        the dataset object itself; the second's is ``tf.identity``.

    Exits:
        For any other ``data_type``, prints an error to stderr and calls
        ``sys.exit(1)``.
    """
    if data_type == "float32":
        dtype = tf.float32
        create_dataset = float32_dataset(shape=shape, sample_count=sample_count)
    elif data_type == "uint8":
        dtype = tf.uint8
        create_dataset = uint8_dataset(shape=shape, sample_count=sample_count)
    else:
        # Report the requested type: `dtype` is never bound on this branch,
        # so interpolating it would raise UnboundLocalError instead.
        print(
            f"ERROR: Unsupported dtype in pipeline_definition: {data_type}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Both stages consume and produce batches of `shape[1]` columns with an
    # unknown leading dimension; TensorSpec is an immutable value object, so
    # one instance can be shared by all schema entries.
    spec = tf.TensorSpec([None, shape[1]], dtype)
    return [
        {
            "name": f"create-dataset-{data_type}",
            "type": "source",
            "op": create_dataset,
            "input_schema": spec,
            "output_schema": spec,
        },
        {
            "name": "identity",
            "type": "op",
            "op": tf.identity,
            "input_schema": spec,
            "output_schema": spec,
        },
    ]
def uint8_dataset(shape, sample_count):
    """Build a ``tf.data.Dataset`` that yields random uint8 arrays.

    Args:
        shape: Shape of each generated array. The dataset's declared element
            shape is ``[None, shape[1]]``, so ``shape`` is expected to be 2-D
            (assumption -- TODO confirm with callers).
        sample_count: Number of arrays yielded per pass over the dataset.

    Returns:
        A ``tf.data.Dataset`` created with ``from_generator``. Each element is
        an ``np.uint8`` array drawn uniformly over the full 0..255 range.
    """
    # `np.random.randint` treats `high` as exclusive, so the uint8 maximum
    # plus one is needed for 255 to be reachable (high=255 excluded it).
    upper_exclusive = int(np.iinfo(np.uint8).max) + 1

    def uint8_generator(shape, sample_count):
        """Yield `sample_count` random uint8 arrays of the given `shape`."""
        for _ in range(sample_count):
            yield np.random.randint(
                low=0, high=upper_exclusive, size=shape, dtype=np.uint8
            )

    def generator():
        # from_generator takes a zero-argument callable that builds a fresh
        # generator, so bind the arguments here.
        return uint8_generator(shape=shape, sample_count=sample_count)

    return tf.data.Dataset.from_generator(
        generator=generator,
        output_types=tf.uint8,
        output_shapes=tf.TensorShape([None, shape[1]]),
    )
def float32_dataset(shape, sample_count):
    """Return a ``tf.data.Dataset`` streaming random float32 arrays.

    Each of the ``sample_count`` elements is drawn with ``np.random.uniform``
    between ``-2**15`` and ``2**15 - 1`` (upper bound exclusive before the
    cast) and then converted to ``np.float32``. The declared element shape is
    ``[None, shape[1]]``, so ``shape`` is presumably 2-D (TODO confirm with
    callers).
    """
    def _float32_samples():
        # Bounds equal the int16 minimum and maximum values.
        lowest = -2**15
        highest = 2**15 - 1
        for _ in range(sample_count):
            values = np.random.uniform(low=lowest, high=highest, size=shape)
            yield values.astype(np.float32)

    element_shape = tf.TensorShape([None, shape[1]])
    return tf.data.Dataset.from_generator(
        generator=_float32_samples,
        output_types=tf.float32,
        output_shapes=element_shape,
    )