using tf.data for fit method instead of #890
base: develop
Conversation
Various comments. Happy to review again (including the tests) once the tests are passing.
```python
    batch_size: int,
    num_points: int,
    validation_split: float = 0.0,
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
```

might be nicer to always return a tuple?

```diff
-) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
+) -> tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
```
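A minimal sketch of the always-return-a-pair shape the suggestion is going for, using plain lists in place of `tf.data.Dataset` (the function name and details here are illustrative, not from the PR):

```python
from typing import Optional

def split_data(points: list, validation_split: float = 0.0) -> tuple[list, Optional[list]]:
    """Always return a (train, validation) pair; validation is None when
    no split was requested, so callers can unpack uniformly."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")
    if validation_split == 0.0:
        return points, None
    val_size = round(len(points) * validation_split)
    return points[val_size:], points[:val_size]

train, val = split_data(list(range(10)), 0.2)  # val gets 2 points, train gets 8
```

The benefit is that the call site is a single unconditional unpack, with an `is None` check replacing the `isinstance`/length check a `Union` return would force.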
```python
    If validation_split > 0, returns a tuple of (training_dataset, validation_dataset)
    """
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be between 0 and 1")
```

```diff
-        raise ValueError("validation_split must be between 0 and 1")
+        raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")
```
```python
    if validation_split > 0:
        # Calculate split sizes
        val_size = int(num_points * validation_split)
```

```diff
-        val_size = int(num_points * validation_split)
+        val_size = round(num_points * validation_split)
```
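For context (my own illustration, not from the PR): `int` truncates towards zero, while Python's `round` rounds to the nearest integer (ties to even), so the two can differ by one point whenever the split lands on a fraction:

```python
num_points, validation_split = 7, 0.5

# 7 * 0.5 == 3.5: int() truncates, round() rounds the .5 tie to even
print(int(num_points * validation_split))    # 3
print(round(num_points * validation_split))  # 4
```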
```python
    tf_data = self.prepare_tf_data(
        x,
        y,
        batch_size=fit_args_copy["batch_size"],
```

"batch_size" isn't guaranteed to exist for a user-supplied fit_args

```diff
-        batch_size=fit_args_copy["batch_size"],
+        batch_size=fit_args_copy.get("batch_size"),
```
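A quick illustration (with hypothetical fit args) of why `.get` is the safer access here: subscripting raises `KeyError` when the key is absent, while `.get` returns `None` or an explicit default:

```python
fit_args_copy = {"epochs": 10}  # hypothetical user-supplied fit args, no batch_size

# fit_args_copy["batch_size"] would raise KeyError here
print(fit_args_copy.get("batch_size"))      # None
print(fit_args_copy.get("batch_size", 32))  # 32, with an explicit default
```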
```python
    x, y = self.prepare_dataset(dataset)

    validation_split = fit_args_copy.pop("validation_split", 0.0)
    tf_data = self.prepare_tf_data(
```

(if you change the return type above as suggested)

```diff
-    tf_data = self.prepare_tf_data(
+    train_dataset, val_dataset = self.prepare_tf_data(
```
```python
    if validation_split > 0:
        train_dataset, val_dataset = tf_data
        fit_args_copy["validation_data"] = val_dataset
```

should we maybe raise an exception if "validation_data" is already present in the fit_args?
```python
        history = self.model.fit(
            train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
        )
    else:
        history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
```

```diff
-        history = self.model.fit(
-            train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
-        )
-    else:
-        history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
+    history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
```
```python
    # Original behavior when no validation split is requested
    return (
        dataset.prefetch(tf.data.AUTOTUNE)
        .shuffle(train_size, reshuffle_each_iteration=True)
```

(I think?)

```diff
-        .shuffle(train_size, reshuffle_each_iteration=True)
+        .shuffle(num_points, reshuffle_each_iteration=True)
```
```python
    return train_dataset, val_dataset
else:
    # Original behavior when no validation split is requested
```

Q: is this really the same as the original behaviour?

I do have a small concern about the use of … here. This might not be such a problem with smaller datasets, so perhaps it would be a good idea to do some benchmarking? However, it's also not clear to me why we're introducing shuffling by default here, when AFAICT it wasn't there before; this seems like a change of behaviour. Do we expect it to improve model accuracy? It may be better to let the user of Trieste control this, rather than making it the default behaviour.
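To make the memory concern concrete: `tf.data.Dataset.shuffle(buffer_size)` keeps `buffer_size` elements in memory, so a buffer equal to the full dataset size buffers the entire dataset. A pure-Python sketch of the same streaming buffered-shuffle idea (illustrative only, not tf.data's actual implementation):

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Streaming shuffle with a bounded buffer: memory use is
    O(buffer_size), and only buffer_size >= len(items) can fully
    shuffle the stream."""
    rng = random.Random(seed)
    buf, out = [], []
    for item in items:
        buf.append(item)
        if len(buf) >= buffer_size:
            # once the buffer is full, emit a randomly chosen buffered element
            out.append(buf.pop(rng.randrange(len(buf))))
    rng.shuffle(buf)  # drain whatever is left at the end of the stream
    out.extend(buf)
    return out

data = list(range(100))
assert sorted(buffered_shuffle(data, buffer_size=10)) == data  # a permutation
assert buffered_shuffle(data, buffer_size=1) == data  # a buffer of 1 cannot shuffle
```

With a buffer of 1 the output order is unchanged, which is why a meaningful shuffle of the whole dataset needs a buffer on the order of `num_points` — and hence that much data resident in memory.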
This simple change should improve memory handling, be better optimized for GPUs, and give the user more control over preparing data for training.