diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index d22dce661a..8e8176441f 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -187,8 +187,9 @@ def __post__init__(self): or self.dargs.precision == "amp" or self.dargs.precision == "int8-static" or self.dargs.precision == "int8-dynamic" - or os.getenv("DNNL_DEFAULT_FPMATH_MODE").lower() == "any" - or os.getenv("DNNL_DEFAULT_FPMATH_MODE").lower() == "bf16" + or (os.getenv("DNNL_DEFAULT_FPMATH_MODE") != None and + (os.getenv("DNNL_DEFAULT_FPMATH_MODE").lower() == "any" or + os.getenv("DNNL_DEFAULT_FPMATH_MODE").lower() == "bf16")) or (self.dynamo and self.opt_args.torchdynamo == "fx2trt") or (not self.dynamo and self.opt_args.fx2trt) or (not self.dynamo and self.opt_args.use_cosine_similarity)