diff --git a/tests/unit/models/keras/test_models.py b/tests/unit/models/keras/test_models.py
index 58d77ecbc..7f5fa3efc 100644
--- a/tests/unit/models/keras/test_models.py
+++ b/tests/unit/models/keras/test_models.py
@@ -42,6 +42,7 @@
     negative_log_likelihood,
     sample_with_replacement,
 )
+from trieste.models.keras.builders import build_keras_ensemble
 from trieste.models.optimizer import KerasOptimizer, TrainingData
 from trieste.models.utils import (
     get_last_optimization_result,
@@ -825,3 +826,55 @@ def test_deep_ensemble_log(
 
     assert mocked_summary_scalar.call_count == num_scalars
     assert mocked_summary_histogram.call_count == num_histogram
+
+
+@pytest.mark.slow
+@random_seed
+def test_deep_ensemble_parallel_training_performance() -> None:
+    """
+    Verify that doubling the ensemble size does not double the training time, as a check that
+    the networks are trained in parallel; some overhead is allowed, but well below 2x.
+    """
+    # Create a larger dataset with more features to increase computation per network
+    example_data = _get_example_data([10000, 10], [10000, 1])  # 10D input, 1D output
+
+    # Test with different ensemble sizes
+    ensemble_sizes = [5, 10]
+    ensemble_units = [500, 352]  # keep the total number of parameters comparable across sizes
+    training_times = []
+
+    for units, size in zip(ensemble_units, ensemble_sizes):
+        # Create a larger network to increase computation
+        keras_ensemble = build_keras_ensemble(
+            example_data, size, num_hidden_layers=3, units=units, independent_normal=True
+        )
+        optimizer = tf_keras.optimizers.Adam()
+        fit_args = {
+            "batch_size": 512,  # larger batch size for better parallelization
+            "epochs": 3,  # more epochs to amortize setup costs
+            "callbacks": [],
+            "verbose": 0,
+        }
+        optimizer_wrapper = KerasOptimizer(optimizer, fit_args)
+        model = DeepEnsemble(
+            keras_ensemble,
+            optimizer_wrapper,
+            True,
+            compile_args={"jit_compile": True},  # enable XLA compilation
+        )
+        model.model.summary()  # print the model architecture for debugging
+
+        # Time the training
+        start_time = tf.timestamp()
+        model.optimize(example_data)
+        end_time = tf.timestamp()
+        training_times.append(float(end_time - start_time))
+
+    print(f"Training times: {training_times}")
+    print(f"Time ratio (10/5 networks): {training_times[1] / training_times[0]:.3f}")
+
+    # Allow some overhead but still expect a significant parallelization benefit
+    assert training_times[1] / training_times[0] < 1.7, (
+        f"Training time ratio {training_times[1] / training_times[0]:.3f} suggests "
+        "that the ensemble networks may not be trained in parallel"
+    )