Skip to content

Commit

Permalink
no need to check dimension each time a data generator generates a batch
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Sep 7, 2024
1 parent 831442b commit 1352dba
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 70 deletions.
14 changes: 7 additions & 7 deletions src/astroNN/models/base_bayesian_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
)

def _data_generation(self, idx_list_temp):
x = self.input_d_checking(self.inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
if "labels_err" in x.keys():
x.update({"labels_err": np.squeeze(x["labels_err"])})
y = {}
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(

def _data_generation(self, idx_list_temp):
# Generate data
x = self.input_d_checking(self.inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
return x

def __getitem__(self, index):
Expand Down Expand Up @@ -226,8 +226,8 @@ def pre_training_checklist_child(self, input_data, labels, sample_weight):
# No need to care about Magic number as loss function looks for magic num in y_true only
norm_data.update(
{
"input_err": (input_data["input_err"] / self.input_std["input"]),
"labels_err": input_data["labels_err"] / self.labels_std["output"],
"input_err": (input_data["input_err"] / self.input_std["input"]).astype(np.float32),
"labels_err": (input_data["labels_err"] / self.labels_std["output"]).astype(np.float32),
}
)
norm_labels.update({"variance_output": norm_labels["output"]})
Expand Down Expand Up @@ -682,8 +682,8 @@ def fit_on_batch(
# No need to care about Magic number as loss function looks for magic num in y_true only
norm_data.update(
{
"input_err": (input_data["input_err"] / self.input_std["input"]),
"labels_err": input_data["labels_err"] / self.labels_std["output"],
"input_err": (input_data["input_err"] / self.input_std["input"]).astype(np.float32),
"labels_err": (input_data["labels_err"] / self.labels_std["output"]).astype(np.float32),
}
)
norm_labels.update({"variance_output": norm_labels["output"]})
Expand Down Expand Up @@ -956,7 +956,7 @@ def _data_generation(self, idx_list_temp):
},
calc=False,
)
x = self.input_d_checking(inputs, np.arange(len(idx_list_temp)))
x = self.get_idx_item(inputs, np.arange(len(idx_list_temp)))
return x

def __getitem__(self, index):
Expand Down
4 changes: 2 additions & 2 deletions src/astroNN/models/base_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
)

def _data_generation(self, idx_list_temp):
x = self.input_d_checking(self.inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
y = {}
for name in self.labels.keys():
y.update({name: self.labels[name][idx_list_temp]})
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(self, batch_size, shuffle, steps_per_epoch, data, pbar=None):

def _data_generation(self, idx_list_temp):
# Generate data
x = self.input_d_checking(self.inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
return x

def __getitem__(self, index):
Expand Down
6 changes: 3 additions & 3 deletions src/astroNN/models/base_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(
)

def _data_generation(self, idx_list_temp):
x = self.input_d_checking(self.inputs, idx_list_temp)
y = self.input_d_checking(self.recon_inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
y = self.get_idx_item(self.recon_inputs, idx_list_temp)
if self.sample_weight is not None:
return x, y, self.sample_weight[idx_list_temp]
else:
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(

def _data_generation(self, idx_list_temp):
# Generate data
x = self.input_d_checking(self.inputs, idx_list_temp)
x = self.get_idx_item(self.inputs, idx_list_temp)
return x

def __getitem__(self, index):
Expand Down
8 changes: 6 additions & 2 deletions src/astroNN/models/nn_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,18 @@ def pre_training_checklist_master(self, input_data, labels):
# handle named inputs/outputs first
try:
self.input_names = list(input_data.keys())
# if input_data is a dict, cast all values to float32
input_data = {name: input_data[name].astype(np.float32) for name in self.input_names}
except AttributeError:
self.input_names = ["input"] # default input name in all astroNN models
input_data = {"input": input_data}
input_data = {"input": input_data.astype(np.float32)}
try:
self.output_names = list(labels.keys())
# if labels is a dict, cast all values to float32
labels = {name: labels[name].astype(np.float32) for name in self.output_names}
except AttributeError:
self.output_names = ["output"] # default input name in all astroNN models
labels = {"output": labels}
labels = {"output": labels.astype(np.float32)}

# assert all named input has the same number of data points
# TODO: add detail error msg, add test
Expand Down
71 changes: 18 additions & 53 deletions src/astroNN/nn/utilities/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ class GeneratorBase(keras.utils.PyDataset):
Parameters
----------
batch_size: int
data: dict
data dictionary
batch_size: int, optional (default is 64)
batch size
shuffle: bool
shuffle: bool, optional (default is True)
shuffle the data or not after each epoch
steps_per_epoch: int
steps_per_epoch: int, optional (default is None)
steps per epoch
data: dict
data dictionary
np_rng: numpy.random.Generator
np_rng: numpy.random.Generator, optional (default is None)
numpy random generator
History
Expand All @@ -28,7 +28,7 @@ class GeneratorBase(keras.utils.PyDataset):
2024-Sept-6 - Updated - Henry Leung (University of Toronto)
"""

def __init__(self, data, *, batch_size=32, shuffle=True, steps_per_epoch=None, np_rng=None, **kwargs):
def __init__(self, data, *, batch_size=64, shuffle=True, steps_per_epoch=None, np_rng=None, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
self.data = data
Expand All @@ -54,49 +54,14 @@ def _get_exploration_order(self, idx_list):
self.np_rng.shuffle(idx_list)

return idx_list

def input_d_checking(self, inputs, idx_list_temp):
x_dict = {}
float_dtype = keras.backend.floatx()
for name in inputs.keys():
if inputs[name].ndim == 2:
x = np.empty(
(len(idx_list_temp), inputs[name].shape[1], 1),
dtype=float_dtype,
)
# Generate data
x[:, :, 0] = inputs[name][idx_list_temp]

elif inputs[name].ndim == 3:
x = np.empty(
(
len(idx_list_temp),
inputs[name].shape[1],
inputs[name].shape[2],
1,
),
dtype=float_dtype,
)
# Generate data
x[:, :, :, 0] = inputs[name][idx_list_temp]

elif inputs[name].ndim == 4:
x = np.empty(
(
len(idx_list_temp),
inputs[name].shape[1],
inputs[name].shape[2],
inputs[name].shape[3],
),
dtype=float_dtype,
)
# Generate data
x[:, :, :, :] = inputs[name][idx_list_temp]
else:
raise ValueError(
f"Unsupported data dimension, your data has {inputs[name].ndim} dimension"
)

x_dict.update({name: x})

return x_dict

def get_idx_item(self, data, idx):
    """
    Slice a batch out of ``data`` using ``idx``.

    ``data`` may be a dict of arrays (sliced value-wise, keys kept), a
    list of arrays (sliced element-wise), or a single array-like object
    (sliced directly). The same ``idx`` is applied to every member.
    """
    if isinstance(data, dict):
        return {name: values[idx] for name, values in data.items()}
    if isinstance(data, list):
        return [member[idx] for member in data]
    return data[idx]
3 changes: 0 additions & 3 deletions src/astroNN/nn/utilities/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def mode_checker(self, data):
"doing nothing because no normalization can be done on bool"
)
self.normalization_mode[name] = "0"
data_array = data_array.astype(
np.float32, copy=False
) # need to convert data to float in every case

if self.normalization_mode[name] == "0":
self.featurewise_center.update({name: False})
Expand Down

0 comments on commit 1352dba

Please sign in to comment.