diff --git a/src/astroNN/models/base_bayesian_cnn.py b/src/astroNN/models/base_bayesian_cnn.py index 8b7114f8..72b08ea6 100644 --- a/src/astroNN/models/base_bayesian_cnn.py +++ b/src/astroNN/models/base_bayesian_cnn.py @@ -81,7 +81,7 @@ def __init__( ) def _data_generation(self, idx_list_temp): - x = self.input_d_checking(self.inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) if "labels_err" in x.keys(): x.update({"labels_err": np.squeeze(x["labels_err"])}) y = {} @@ -141,7 +141,7 @@ def __init__( def _data_generation(self, idx_list_temp): # Generate data - x = self.input_d_checking(self.inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) return x def __getitem__(self, index): @@ -226,8 +226,8 @@ def pre_training_checklist_child(self, input_data, labels, sample_weight): # No need to care about Magic number as loss function looks for magic num in y_true only norm_data.update( { - "input_err": (input_data["input_err"] / self.input_std["input"]), - "labels_err": input_data["labels_err"] / self.labels_std["output"], + "input_err": (input_data["input_err"] / self.input_std["input"]).astype(np.float32), + "labels_err": (input_data["labels_err"] / self.labels_std["output"]).astype(np.float32), } ) norm_labels.update({"variance_output": norm_labels["output"]}) @@ -682,8 +682,8 @@ def fit_on_batch( # No need to care about Magic number as loss function looks for magic num in y_true only norm_data.update( { - "input_err": (input_data["input_err"] / self.input_std["input"]), - "labels_err": input_data["labels_err"] / self.labels_std["output"], + "input_err": (input_data["input_err"] / self.input_std["input"]).astype(np.float32), + "labels_err": (input_data["labels_err"] / self.labels_std["output"]).astype(np.float32), } ) norm_labels.update({"variance_output": norm_labels["output"]}) @@ -956,7 +956,7 @@ def _data_generation(self, idx_list_temp): }, calc=False, ) - x = self.input_d_checking(inputs, np.arange(len(idx_list_temp))) + x = self.get_idx_item(inputs, np.arange(len(idx_list_temp))) return x def __getitem__(self, index): diff --git a/src/astroNN/models/base_cnn.py b/src/astroNN/models/base_cnn.py index d5c616bd..22124a34 100644 --- a/src/astroNN/models/base_cnn.py +++ b/src/astroNN/models/base_cnn.py @@ -67,7 +67,7 @@ def __init__( ) def _data_generation(self, idx_list_temp): - x = self.input_d_checking(self.inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) y = {} for name in self.labels.keys(): y.update({name: self.labels[name][idx_list_temp]}) @@ -123,7 +123,7 @@ def __init__(self, batch_size, shuffle, steps_per_epoch, data, pbar=None): def _data_generation(self, idx_list_temp): # Generate data - x = self.input_d_checking(self.inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) return x def __getitem__(self, index): diff --git a/src/astroNN/models/base_vae.py b/src/astroNN/models/base_vae.py index 42cc24ae..1423c92c 100644 --- a/src/astroNN/models/base_vae.py +++ b/src/astroNN/models/base_vae.py @@ -65,8 +65,8 @@ def __init__( ) def _data_generation(self, idx_list_temp): - x = self.input_d_checking(self.inputs, idx_list_temp) - y = self.input_d_checking(self.recon_inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) + y = self.get_idx_item(self.recon_inputs, idx_list_temp) if self.sample_weight is not None: return x, y, self.sample_weight[idx_list_temp] else: @@ -130,7 +130,7 @@ def __init__( def _data_generation(self, idx_list_temp): # Generate data - x = self.input_d_checking(self.inputs, idx_list_temp) + x = self.get_idx_item(self.inputs, idx_list_temp) return x def __getitem__(self, index): diff --git a/src/astroNN/models/nn_base.py b/src/astroNN/models/nn_base.py index 171c0e87..d4c4923b 100644 --- a/src/astroNN/models/nn_base.py +++ b/src/astroNN/models/nn_base.py @@ -201,14 +201,18 @@ def pre_training_checklist_master(self, input_data, labels): # handle named inputs/outputs first try: self.input_names = list(input_data.keys()) + # if input_data is a dict, cast all values to float32 + input_data = {name: input_data[name].astype(np.float32) for name in self.input_names} except AttributeError: self.input_names = ["input"] # default input name in all astroNN models - input_data = {"input": input_data} + input_data = {"input": input_data.astype(np.float32)} try: self.output_names = list(labels.keys()) + # if labels is a dict, cast all values to float32 + labels = {name: labels[name].astype(np.float32) for name in self.output_names} except AttributeError: self.output_names = ["output"] # default input name in all astroNN models - labels = {"output": labels} + labels = {"output": labels.astype(np.float32)} # assert all named input has the same number of data points # TODO: add detail error msg, add test diff --git a/src/astroNN/nn/utilities/generator.py b/src/astroNN/nn/utilities/generator.py index c96e2646..7525d5ff 100644 --- a/src/astroNN/nn/utilities/generator.py +++ b/src/astroNN/nn/utilities/generator.py @@ -11,15 +11,15 @@ class GeneratorBase(keras.utils.PyDataset): Parameters ---------- - batch_size: int + data: dict + data dictionary + batch_size: int, optional (default is 64) batch size - shuffle: bool + shuffle: bool, optional (default is True) shuffle the data or not after each epoch - steps_per_epoch: int + steps_per_epoch: int, optional (default is None) steps per epoch - data: dict - data dictionary - np_rng: numpy.random.Generator + np_rng: numpy.random.Generator, optional (default is None) numpy random generator History @@ -28,7 +28,7 @@ class GeneratorBase(keras.utils.PyDataset): 2024-Sept-6 - Updated - Henry Leung (University of Toronto) """ - def __init__(self, data, *, batch_size=32, shuffle=True, steps_per_epoch=None, np_rng=None, **kwargs): + def __init__(self, data, *, batch_size=64, shuffle=True, steps_per_epoch=None, np_rng=None, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size self.data = data @@ -54,49 +54,14 @@ def _get_exploration_order(self, idx_list): self.np_rng.shuffle(idx_list) return idx_list - - def input_d_checking(self, inputs, idx_list_temp): - x_dict = {} - float_dtype = keras.backend.floatx() - for name in inputs.keys(): - if inputs[name].ndim == 2: - x = np.empty( - (len(idx_list_temp), inputs[name].shape[1], 1), - dtype=float_dtype, - ) - # Generate data - x[:, :, 0] = inputs[name][idx_list_temp] - - elif inputs[name].ndim == 3: - x = np.empty( - ( - len(idx_list_temp), - inputs[name].shape[1], - inputs[name].shape[2], - 1, - ), - dtype=float_dtype, - ) - # Generate data - x[:, :, :, 0] = inputs[name][idx_list_temp] - - elif inputs[name].ndim == 4: - x = np.empty( - ( - len(idx_list_temp), - inputs[name].shape[1], - inputs[name].shape[2], - inputs[name].shape[3], - ), - dtype=float_dtype, - ) - # Generate data - x[:, :, :, :] = inputs[name][idx_list_temp] - else: - raise ValueError( - f"Unsupported data dimension, your data has {inputs[name].ndim} dimension" - ) - - x_dict.update({name: x}) - - return x_dict + + def get_idx_item(self, data, idx): + """ + Get batch data with index + """ + if isinstance(data, dict): + return {key: data[key][idx] for key in data.keys()} + elif isinstance(data, list): + return [data[i][idx] for i in range(len(data))] + else: + return data[idx] diff --git a/src/astroNN/nn/utilities/normalizer.py b/src/astroNN/nn/utilities/normalizer.py index eff37cef..3797eb1b 100644 --- a/src/astroNN/nn/utilities/normalizer.py +++ b/src/astroNN/nn/utilities/normalizer.py @@ -70,9 +70,6 @@ def mode_checker(self, data): "doing nothing because no normalization can be done on bool" ) self.normalization_mode[name] = "0" - data_array = data_array.astype( - np.float32, copy=False - ) # need to convert data to float in every case if self.normalization_mode[name] == "0": self.featurewise_center.update({name: False})