From b403d2934e1a6affcccd666efaa6eb79f161036f Mon Sep 17 00:00:00 2001 From: FeSens Date: Mon, 6 Jun 2022 11:02:06 -0300 Subject: [PATCH 1/2] assert inputs are the rigth type --- uq360/algorithms/variational_bayesian_neural_networks/bnn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py index 50b2914..5caf0c2 100644 --- a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py +++ b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py @@ -64,6 +64,9 @@ def fit(self, X, y): self """ + assert type(X) == torch.Tensor, f"Expected X type torch.Tensor but found {type(X)}" + assert type(y) == torch.Tensor, f"Expected X type torch.Tensor but found {type(y)}" + torch.manual_seed(1234) optimizer = torch.optim.Adam(self.net.parameters(), lr=self.config['step_size']) neg_elbo = torch.zeros([self.config['num_epochs'], 1]) From 7fef313cc9b16af15b547791ab85894f8a789f66 Mon Sep 17 00:00:00 2001 From: FeSens Date: Mon, 6 Jun 2022 11:28:33 -0300 Subject: [PATCH 2/2] fix typo in the error message --- uq360/algorithms/variational_bayesian_neural_networks/bnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py index 5caf0c2..f66d5aa 100644 --- a/uq360/algorithms/variational_bayesian_neural_networks/bnn.py +++ b/uq360/algorithms/variational_bayesian_neural_networks/bnn.py @@ -65,7 +65,7 @@ def fit(self, X, y): """ assert type(X) == torch.Tensor, f"Expected X type torch.Tensor but found {type(X)}" - assert type(y) == torch.Tensor, f"Expected X type torch.Tensor but found {type(y)}" + assert type(y) == torch.Tensor, f"Expected y type torch.Tensor but found {type(y)}" torch.manual_seed(1234) optimizer = torch.optim.Adam(self.net.parameters(), lr=self.config['step_size'])