Skip to content

Commit

Permalink
Disable tf.where python tests (#764)
Browse files Browse the repository at this point in the history
* Fix errors in pytests scripts that prevented "Where" unit tests from running
* Skip tf.where tests
  • Loading branch information
bani-intelaipg authored Dec 18, 2020
1 parent 26fd7bd commit 69c4b65
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 9 deletions.
7 changes: 4 additions & 3 deletions test/python/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,14 @@ def is_env_variable_set(self, env_var):

# sets the env variable
def set_env_variable(self, env_var, env_var_val):
os.putenv(env_var, env_var_val)
os.environ[env_var] = env_var_val
print("Setting env variable ", env_var, " to ", env_var_val)

# unset the env variable
def unset_env_variable(self, env_var):
os.unsetenv(env_var)
print("Unset env variable ", env_var)
if self.is_env_variable_set(env_var):
os.environ.pop(env_var)
print("Unset env variable ", env_var)

# get the env variable
def get_env_variable(self, env_var):
Expand Down
5 changes: 4 additions & 1 deletion test/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def pattern_to_regex(pattern):
pattern = re.sub(r'\*', '.*', pattern)
# special case for M.C.F when it possibly matches with parameterized tests
if pattern_noparam.count('.') == 2 and no_param:
pattern = '^' + pattern + r'\[.*'
if no_param:
pattern = '^' + pattern + '$'
else:
pattern = '^' + pattern + r'\[.*'
if pattern_noparam.count('.') == 0:
pattern = '^' + pattern + r'\..*\..*' + '$'
if pattern_noparam.count('.') == 1:
Expand Down
2 changes: 1 addition & 1 deletion test/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def test_set_backend_invalid(self):
except:
error_thrown = True
ngraph_bridge.set_backend(current_backend)
assert error_thrown
self.restore_env_variables(env_var_map)
assert error_thrown

def test_list_backends(self):
assert len(ngraph_bridge.list_backends())
Expand Down
8 changes: 4 additions & 4 deletions test/python/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ def run_test(sess):
class TestWhere(NgraphTest):
env_map = None

def __init__(self):
env_map = self.store_env_variables(['NGRAPH_TF_CONSTANT_FOLDING'])
def setup_method(self):
self.env_map = self.store_env_variables(['NGRAPH_TF_CONSTANT_FOLDING'])
self.set_env_variable('NGRAPH_TF_CONSTANT_FOLDING', '1')

def __del__(self):
self.restore_env_variables(env_map)
def teardown_method(self):
self.restore_env_variables(self.env_map)

def test_where(self):
a = np.array([1.1, 3.0], [2.2, 4.4]).astype(np.float32)
Expand Down
3 changes: 3 additions & 0 deletions test/python/tests_common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@

test_bfloat16.TestBfloat16.test_matmul_bfloat16 # most backends do not support dtype bf16 for MatMul/Dot Op
test_conv2dbackpropinput.TestConv2DBackpropInput.test_nhwc # parameterized; Fails, needs debugging

test_select.TestWhere.* # Where op translation not working yet

0 comments on commit 69c4b65

Please sign in to comment.