diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 52f8c3a48..0184b87b4 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import re -from collections import OrderedDict +from collections import OrderedDict, namedtuple from importlib import import_module import numpy as np @@ -89,17 +89,16 @@ def __init__(self, raw_dist, raw_params, expected_value_domain, xfail_reason="") self.raw_dist = re.sub(r"\s+", " ", raw_dist.strip()) self.raw_params = raw_params self.expected_value_domain = expected_value_domain - for name, raw_param in self.raw_params: - if get_backend() != "numpy": - # we need direct access to these tensors for gradient tests - setattr(self, name, eval(raw_param)) TEST_CASES.append( self if not xfail_reason else xfail_param(self, reason=xfail_reason) ) def get_dist(self): dist = backend_dist # noqa: F841 - case = self # noqa: F841 + Case = namedtuple("Case", tuple(name for name, _ in self.raw_params)) + case = Case( # noqa: F841 + **{name: eval(raw_param) for name, raw_param in self.raw_params} + ) with xfail_if_not_found(): return eval(self.raw_dist)