Skip to content

Commit

Permalink
Remove nondeterminism in test_distribution_generic (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Feb 3, 2021
1 parent 155ad4c commit 7a56805
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions test/test_distribution_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from importlib import import_module

import numpy as np
Expand Down Expand Up @@ -89,17 +89,16 @@ def __init__(self, raw_dist, raw_params, expected_value_domain, xfail_reason="")
self.raw_dist = re.sub(r"\s+", " ", raw_dist.strip())
self.raw_params = raw_params
self.expected_value_domain = expected_value_domain
for name, raw_param in self.raw_params:
if get_backend() != "numpy":
# we need direct access to these tensors for gradient tests
setattr(self, name, eval(raw_param))
TEST_CASES.append(
self if not xfail_reason else xfail_param(self, reason=xfail_reason)
)

def get_dist(self):
dist = backend_dist # noqa: F841
case = self # noqa: F841
Case = namedtuple("Case", tuple(name for name, _ in self.raw_params))
case = Case( # noqa: F841
**{name: eval(raw_param) for name, raw_param in self.raw_params}
)
with xfail_if_not_found():
return eval(self.raw_dist)

Expand Down

0 comments on commit 7a56805

Please sign in to comment.