using tf.data for fit method instead of #890

Open · wants to merge 2 commits into develop

Conversation

@hstojic (Collaborator) commented Jan 17, 2025

This simple change should improve memory handling, be better optimized for GPUs, and generally give the user more control over preparing data for training.
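
For context, a minimal sketch (not Trieste code; data, model, and values are purely illustrative) of what feeding a tf.data pipeline directly to Keras' fit looks like:

```python
import tensorflow as tf

# Illustrative data only; in Trieste the tensors would come from the training Dataset.
x = tf.random.normal([1024, 3])
y = tf.random.normal([1024, 1])

# Build the input pipeline once, then hand it straight to fit(); batching,
# shuffling and prefetching are all controlled by the dataset itself.
train_data = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(1024, reshuffle_each_iteration=True)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(train_data, epochs=2)  # no batch_size argument: the dataset is already batched
```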

@hstojic requested review from uri-granta and avullo January 17, 2025 14:19
@uri-granta (Collaborator) left a comment

Various comments. Happy to review again (including the tests) once the tests are passing.

batch_size: int,
num_points: int,
validation_split: float = 0.0,
) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:

might be nicer to always return a tuple?

Suggested change:
- ) -> Union[tf.data.Dataset, tuple[tf.data.Dataset, tf.data.Dataset]]:
+ ) -> tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
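
A sketch of the always-return-a-tuple variant (a standalone illustration assembled from the diff context above, not the actual Trieste method):

```python
from typing import Optional

import tensorflow as tf


def prepare_tf_data(
    x: tf.Tensor,
    y: tf.Tensor,
    batch_size: int,
    num_points: int,
    validation_split: float = 0.0,
) -> tuple[tf.data.Dataset, Optional[tf.data.Dataset]]:
    """Return (train_dataset, val_dataset); val_dataset is None when no split is requested."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")

    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if validation_split > 0:
        val_size = round(num_points * validation_split)
        val_dataset = dataset.take(val_size).batch(batch_size)
        train_dataset = dataset.skip(val_size).batch(batch_size)
        return train_dataset, val_dataset
    # No split: the second element is simply None, so callers can always unpack a pair.
    return dataset.batch(batch_size), None
```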

If validation_split > 0, returns a tuple of (training_dataset, validation_dataset)
"""
if not 0.0 <= validation_split < 1.0:
raise ValueError("validation_split must be between 0 and 1")

Suggested change:
- raise ValueError("validation_split must be between 0 and 1")
+ raise ValueError(f"validation_split must be between 0 and 1: got {validation_split}")


if validation_split > 0:
# Calculate split sizes
val_size = int(num_points * validation_split)

Suggested change:
- val_size = int(num_points * validation_split)
+ val_size = round(num_points * validation_split)
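
A quick illustration of the difference (values chosen purely for illustration): `int()` truncates towards zero, while `round()` goes to the nearest integer, which also guards against floating-point products landing just below a whole number.

```python
num_points, validation_split = 10, 0.15
print(int(num_points * validation_split))    # 1  (1.5 truncated)
print(round(num_points * validation_split))  # 2

# Floating-point representation can also push a product just under the exact value:
print(0.29 * 100)         # 28.999999999999996
print(int(0.29 * 100))    # 28
print(round(0.29 * 100))  # 29
```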

tf_data = self.prepare_tf_data(
x,
y,
batch_size=fit_args_copy["batch_size"],

"batch_size" isn't guaranteed to exist for a user-supplied fit_args

Suggested change:
- batch_size=fit_args_copy["batch_size"],
+ batch_size=fit_args_copy.get("batch_size"),
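
If a fallback is wanted rather than `None`, `dict.get` also accepts a default; a small sketch (the default of 32 is an assumption, not something from the PR):

```python
# A user-supplied fit_args dict may not contain "batch_size", so look it up defensively.
fit_args_copy = {"epochs": 10}                     # example fit args without a batch size
batch_size = fit_args_copy.get("batch_size", 32)   # assumed fallback, not from the PR
print(batch_size)  # 32
```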


x, y = self.prepare_dataset(dataset)

validation_split = fit_args_copy.pop("validation_split", 0.0)
tf_data = self.prepare_tf_data(

(if you change the return type above as suggested)

Suggested change:
- tf_data = self.prepare_tf_data(
+ train_dataset, val_dataset = self.prepare_tf_data(


if validation_split > 0:
train_dataset, val_dataset = tf_data
fit_args_copy["validation_data"] = val_dataset

should we maybe raise an exception if "validation_data" is already present in the fit_args?
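
A sketch of such a guard, assuming the clash to detect is a user-supplied `validation_data` entry in the fit args:

```python
# Hypothetical guard: refuse to silently overwrite a user-supplied validation_data
# when a validation_split has also been requested.
fit_args_copy = {"batch_size": 32, "validation_data": object()}  # illustration only
validation_split = 0.1

if validation_split > 0 and "validation_data" in fit_args_copy:
    raise ValueError(
        "Cannot specify both validation_split and validation_data in fit_args"
    )
```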

Comment on lines +476 to +480
history = self.model.fit(
train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
)
else:
history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)

Suggested change:
- history = self.model.fit(
-     train_dataset, **fit_args_copy, initial_epoch=self._absolute_epochs
- )
- else:
-     history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)
+ history = self.model.fit(tf_data, **fit_args_copy, initial_epoch=self._absolute_epochs)

# Original behavior when no validation split is requested
return (
dataset.prefetch(tf.data.AUTOTUNE)
.shuffle(train_size, reshuffle_each_iteration=True)

(I think?)

Suggested change:
- .shuffle(train_size, reshuffle_each_iteration=True)
+ .shuffle(num_points, reshuffle_each_iteration=True)
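
For reference, the no-split branch pieced together from the context above, with `num_points` as the buffer size (data, batch placement, and values are assumptions for illustration, not the verbatim PR code):

```python
import tensorflow as tf

num_points, batch_size = 1024, 32
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([num_points, 3]), tf.random.normal([num_points, 1]))
)

# Buffer the full dataset so each epoch sees a complete reshuffle.
train_dataset = (
    dataset.prefetch(tf.data.AUTOTUNE)
    .shuffle(num_points, reshuffle_each_iteration=True)
    .batch(batch_size)
)
```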


return train_dataset, val_dataset
else:
# Original behavior when no validation split is requested

Q: is this really the same as the original behaviour?

@uri-granta self-requested a review January 20, 2025 10:10

@pio-neil commented:

I do have a small concern about the use of tf.data.Dataset.shuffle. When I did some testing with this before, the shuffle buffer (which is tf.data.Dataset's internal method of shuffling data) used around 18GB of extra memory. This was with a dataset with 30 million rows, with 15 inputs and one output, and a batch size of 1000. The shuffle buffer also has an impact on speed, but I suspect this is relatively minor compared to the model training time.

This might not be such a problem with smaller datasets. So perhaps it would be a good idea to do some benchmarking?

However, it's also not clear to me why we're introducing shuffling by default here, when AFAICT it wasn't there before? This seems like a change of behaviour. Do we expect this to improve model accuracy? It may be better to let the user of Trieste control this, rather than making it the default behaviour?
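
One way to keep this in the user's hands would be to make the shuffle buffer optional or bounded rather than sized to the whole dataset; a sketch, with the cap and data being purely illustrative:

```python
import tensorflow as tf

MAX_SHUFFLE_BUFFER = 100_000  # illustrative cap, not a value from the PR

x = tf.random.normal([1_000, 15])  # small stand-in; the concern above is ~30M rows
y = tf.random.normal([1_000, 1])
num_points = x.shape[0]

# Memory held by the shuffle buffer scales with buffer_size, so capping it (or
# skipping the shuffle entirely and letting the user pre-shuffle) bounds the
# overhead at the cost of a less thorough shuffle.
dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(min(num_points, MAX_SHUFFLE_BUFFER), reshuffle_each_iteration=True)
    .batch(1_000)
)
```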
