From 5800d056daadee9eadbb85c2843f6634cfab7963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Thu, 30 Nov 2023 12:58:53 +0100 Subject: [PATCH] fix shape --- bofire/strategies/samplers/polytope.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bofire/strategies/samplers/polytope.py b/bofire/strategies/samplers/polytope.py index d2cdc1719..444db3394 100644 --- a/bofire/strategies/samplers/polytope.py +++ b/bofire/strategies/samplers/polytope.py @@ -143,7 +143,7 @@ def _ask(self, n: int) -> pd.DataFrame: n_burnin=self.n_burnin, thinning=self.n_thinning, seed=self.seed, - ) + ).squeeze(dim=0) # check that the random generated candidates are not always the same if (candidates.unique(dim=0).shape[0] != n) and (n > 1): @@ -154,10 +154,9 @@ def _ask(self, n: int) -> pd.DataFrame: for feat in self.domain.get_features(ContinuousInput) if feat.key not in fixed_features.keys() # type: ignore ] - # setup the output samples = pd.DataFrame( - data=candidates.detach().numpy().reshape(n, len(free_continuals)), + data=candidates.detach().numpy(), index=range(n), columns=free_continuals, )