From 71ea41b3723b9ee94b5d9aff38f8cd69b944adce Mon Sep 17 00:00:00 2001 From: SaiAakash Date: Sat, 9 Nov 2024 17:53:24 +0000 Subject: [PATCH 1/2] detach new_covar_cache to enable tracing for fantasized models --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 0a8092e15..0e565cbbd 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -211,7 +211,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ # now update the root and root inverse new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar) new_root = new_lt.root_decomposition().root - new_covar_cache = new_lt.root_inv_decomposition().root + new_covar_cache = new_lt.root_inv_decomposition().root.detach() # Expand inputs accordingly if necessary (for fantasies at the same points) if full_inputs[0].dim() <= full_targets.dim(): From 2363eba474ab7c8056ca04bce49fbb00c9a811e2 Mon Sep 17 00:00:00 2001 From: SaiAakash Date: Wed, 13 Nov 2024 22:34:34 +0000 Subject: [PATCH 2/2] detach only if detach_test_caches is on --- gpytorch/models/exact_prediction_strategies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 0e565cbbd..b0b14aedf 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -211,7 +211,10 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ # now update the root and root inverse new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar) new_root = new_lt.root_decomposition().root - new_covar_cache = new_lt.root_inv_decomposition().root.detach() + if settings.detach_test_caches.on(): + new_covar_cache = new_lt.root_inv_decomposition().root.detach() + else: + new_covar_cache = new_lt.root_inv_decomposition().root # Expand inputs accordingly if necessary (for fantasies at the same points) if full_inputs[0].dim() <= full_targets.dim():