Merge pull request #340 from sony/feature/20210909-disable-tf32
Disable TF32 by default
KazukiYoshiyama-sony authored Sep 16, 2021
2 parents 8c681d2 + 5d465ea commit 926a3a1
Showing 7 changed files with 145 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/nbla/cuda/init.hpp
@@ -1,4 +1,5 @@
// Copyright 2017,2018,2019,2020,2021 Sony Corporation.
// Copyright 2021 Sony Group Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -70,6 +71,11 @@ get_cuda_virtual_caching_allocator_max_available_bytes(const string &device_id);
NBLA_CUDA_API vector<int>
get_cuda_virtual_memory_used_counts(const string &device_id);

/**
* Check if tf32 is enabled or not.
*/
NBLA_CUDA_API bool is_cuda_tf32_enabled();

/** Get CUDA array classes.
*/
NBLA_CUDA_API vector<string> cuda_array_classes();
6 changes: 6 additions & 0 deletions python/src/nnabla_ext/cuda/init.pyx
@@ -1,4 +1,5 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +37,7 @@ cdef extern from "nbla/cuda/init.hpp" namespace "nbla":
    size_t get_cuda_virtual_caching_allocator_fragmentation_bytes(const string& device_id) except +
    size_t get_cuda_virtual_caching_allocator_max_available_bytes(const string& device_id) except +
    vector[int] get_cuda_virtual_memory_used_counts(const string& device_id) except +
    bool is_cuda_tf32_enabled() except +
    vector[string] cuda_array_classes() except +
    void _cuda_set_array_classes(const vector[string] & a) except +
    void cuda_device_synchronize(const string & device) except +
@@ -109,6 +111,10 @@ def get_virtual_memory_used_counts(str device_id):
"""Get # of cuda virtual memory which is currently used."""
return get_cuda_virtual_memory_used_counts(device_id)

def is_tf32_enabled():
"""Check if tf32 is enabled or not."""
return is_cuda_tf32_enabled()

###############################################################################
# Array preference API
# TODO: Move these to C++
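For context, the `is_tf32_enabled()` wrapper added above can be queried as follows. This is a minimal sketch that mirrors the pattern in `test_allow_tf32_main.py` below; it assumes a CUDA-enabled NNabla build with a visible GPU:

```python
import os
import nnabla as nn
from nnabla.ext_utils import get_extension_context

# NNABLA_CUDA_ALLOW_TF32 is only evaluated while the CUDA extension
# initializes, so query the state after creating a CUDA context.
with nn.context_scope(get_extension_context("cuda")):
    import nnabla_ext.cuda.init as cuda_init
    # After init, NVIDIA_TF32_OVERRIDE reflects the NNabla-side decision.
    print(cuda_init.is_tf32_enabled(), os.environ.get("NVIDIA_TF32_OVERRIDE"))
```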
55 changes: 55 additions & 0 deletions python/test/cuda/test_allow_tf32.py
@@ -0,0 +1,55 @@
# Copyright 2017,2018,2019,2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import pytest
import os
import subprocess


def teardown_function(function):
    if os.getenv('NNABLA_CUDA_ALLOW_TF32'):
        del os.environ["NNABLA_CUDA_ALLOW_TF32"]
    if os.getenv('NVIDIA_TF32_OVERRIDE'):
        del os.environ["NVIDIA_TF32_OVERRIDE"]


'''
This test verifies that TF32 is set correctly by the environment variables.
'''


@pytest.mark.parametrize("nnabla_tf32", [None, "0", "1"])
@pytest.mark.parametrize("nvidia_tf32", [None, "0", "1"])
@pytest.mark.parametrize("context", ['cuda', 'cudnn'])
def test_allow_tf32(nnabla_tf32, nvidia_tf32, context):

    if nnabla_tf32:
        os.environ["NNABLA_CUDA_ALLOW_TF32"] = nnabla_tf32
    if nvidia_tf32:
        os.environ["NVIDIA_TF32_OVERRIDE"] = nvidia_tf32

    d = os.path.dirname(os.path.abspath(__file__))
    test_main_path = d + "/test_allow_tf32_main.py"

    if nnabla_tf32 == "1":
        result = subprocess.run(
            ["python", test_main_path, "--context", context, "--allow-tf32"])
    else:
        result = subprocess.run(
            ["python", test_main_path, "--context", context])

    assert result.returncode == 0
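Since only `NNABLA_CUDA_ALLOW_TF32` decides the outcome (the `nvidia_tf32` parameter must not affect it), the expectation handed to the child process follows a one-line rule. A sketch of that rule, derived from the `init_cuda()` change in `src/nbla/cuda/init.cpp.tmpl` further down; the helper name `tf32_expected` is illustrative only:

```python
def tf32_expected(nnabla_tf32):
    # Unset or "0" disables TF32; any other value enables it,
    # overriding whatever NVIDIA_TF32_OVERRIDE was set to.
    return nnabla_tf32 is not None and nnabla_tf32 != "0"

# The nine parametrized cases above reduce to:
assert tf32_expected(None) is False
assert tf32_expected("0") is False
assert tf32_expected("1") is True
```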
46 changes: 46 additions & 0 deletions python/test/cuda/test_allow_tf32_main.py
@@ -0,0 +1,46 @@
# Copyright 2017,2018,2019,2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import nnabla as nn
import os
import argparse

'''
This script is called by test_allow_tf32.py to verify that TF32 is set correctly by the environment variables.
Under pytest, init_cuda is executed only once for the whole test session, so the behavior cannot be verified across parameter combinations in-process; each case therefore runs in a separate subprocess.
'''


def main(args):
    context = args.context

    from nnabla.ext_utils import get_extension_context
    with nn.context_scope(get_extension_context(context)):
        import nnabla_ext.cuda.init as cuda_init
        allow_tf32 = cuda_init.is_tf32_enabled()

    assert args.allow_tf32 == allow_tf32


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--allow-tf32", action='store_true', default=False)
    parser.add_argument("--context", type=str,
                        default="cuda", choices=["cudnn", "cuda"])
    args = parser.parse_args()
    main(args)
5 changes: 5 additions & 0 deletions readme.md
@@ -37,6 +37,11 @@ However, it often consumes much memory due to a big workspace memory required by

In some cases it may be desired to restrict the automatic search for CUDNN Convolution algorithms to those that give deterministic (reproducible) results. This can be achieved by setting an environment variable `NNABLA_CUDNN_DETERMINISTIC` to some value other than `0`.

### TensorFloat-32 (TF32)

In NNabla, the environment variable `NNABLA_CUDA_ALLOW_TF32` controls whether TF32 (for details on TF32, see [a blog post](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) from NVIDIA) may be used. If `NNABLA_CUDA_ALLOW_TF32` is not set (the default) or is `0`, TF32 is disabled; otherwise, it is enabled. `NNABLA_CUDA_ALLOW_TF32` always takes priority over `NVIDIA_TF32_OVERRIDE`. `NNABLA_CUDA_ALLOW_TF32` is evaluated only when the NNabla CUDA extension is initialized; if it is changed within the user program afterward, the behavior is undefined.

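A minimal way to confirm the effective setting from Python (a sketch following the test added in this commit; assumes a CUDA-enabled build and an available GPU):

```python
import os

# Must be set before the NNabla CUDA extension is initialized;
# changing it afterward has no effect (behavior is undefined).
os.environ["NNABLA_CUDA_ALLOW_TF32"] = "1"

import nnabla as nn
from nnabla.ext_utils import get_extension_context

with nn.context_scope(get_extension_context("cudnn")):
    import nnabla_ext.cuda.init as cuda_init
    assert cuda_init.is_tf32_enabled()
```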

## FAQ

No FAQ so far.
11 changes: 11 additions & 0 deletions src/nbla/cuda/cublas.cpp
@@ -1,4 +1,5 @@
// Copyright 2017,2018,2019,2020,2021 Sony Corporation.
// Copyright 2021 Sony Group Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -68,11 +69,16 @@ void cublas_gemm<half>(cublasHandle_t handle, cublasOperation_t op_x,
  cudaDeviceProp prop = cuda_get_current_device_properties();
  if (prop.major >= 5) {
    auto ct = cuda_data_type<typename CudaTypeForceFloat<half>::type>::type();
#if CUDA_VERSION < 11000
    // CUBLAS_TENSOR_OP_MATH is deprecated in CUDA 11.0
    NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif
    NBLA_CUBLAS_CHECK(cublasGemmEx(handle, op_x, op_y, m, n, k, &a, x, dt, lda,
                                   y, dt, ldb, &b, z, dt, ldc, ct,
                                   infer_gemm_algo_by_type(dt)));
#if CUDA_VERSION < 11000
    NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
  } else {
    NBLA_CUBLAS_CHECK(cublasSgemmEx(handle, op_x, op_y, m, n, k, &a, x, dt, lda,
                                    y, dt, ldb, &b, z, dt, ldc));
@@ -299,12 +305,17 @@ void cublas_gemm_strided_batched<half>(
    float b = beta;
    cudaDataType_t dt = cuda_data_type<half>::type();
    cudaDataType_t ct = cuda_data_type<float>::type();
#if CUDA_VERSION < 11000
    // CUBLAS_TENSOR_OP_MATH is deprecated in CUDA 11.0
    NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif
    NBLA_CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        handle, op_x, op_y, m, n, k, &a, x, dt, lda, stride_a, y, dt, ldb,
        stride_b, &b, z, dt, ldc, stride_c, batch_count, ct,
        infer_gemm_algo_by_type(dt)));
#if CUDA_VERSION < 11000
    NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
    return;
  }
#endif // CUDA_VERSION >= 9010
16 changes: 16 additions & 0 deletions src/nbla/cuda/init.cpp.tmpl
@@ -1,5 +1,6 @@

// Copyright 2018,2019,2020,2021 Sony Corporation.
// Copyright 2021 Sony Group Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -65,6 +66,15 @@ void init_cuda() {
  // Init CPU features
  init_cpu();

  // Set NVIDIA_TF32_OVERRIDE based on NNABLA_CUDA_ALLOW_TF32.
  // NNabla does not use TF32 by default.
  const char* allow_tf32 = std::getenv("NNABLA_CUDA_ALLOW_TF32");
  if (!allow_tf32 || strcmp(allow_tf32, "0") == 0) {
    nbla_setenv("NVIDIA_TF32_OVERRIDE", "0");
  } else {
    nbla_setenv("NVIDIA_TF32_OVERRIDE", "1");
  }

  // Init Cuda Driver API
  NBLA_CUDA_DRIVER_CHECK(cuInit(0));

@@ -301,6 +311,12 @@ size_t get_cuda_virtual_caching_allocator_max_available_bytes(const string& devi
vector<int> get_cuda_virtual_memory_used_counts(const string& device_id) {return {};}
#endif // CUDA_VERSION >= 10020 && CUDNN_VERSION >= 8000

bool is_cuda_tf32_enabled() {
  const char* tf32_enable = std::getenv("NVIDIA_TF32_OVERRIDE");
  if (!tf32_enable || strcmp(tf32_enable, "0") == 0) return false;
  return true;
}

/** Get CUDA array classes.
*/
vector<string> cuda_array_classes() {