From 5d465eaf74cb0e49acc87d693b1bc7b113ff4891 Mon Sep 17 00:00:00 2001 From: nnabla developer Date: Fri, 10 Sep 2021 13:07:26 +0900 Subject: [PATCH] disable tf32 as default, deprecate CUBLAS_TENSOR_OP_MATH --- include/nbla/cuda/init.hpp | 6 +++ python/src/nnabla_ext/cuda/init.pyx | 6 +++ python/test/cuda/test_allow_tf32.py | 55 ++++++++++++++++++++++++ python/test/cuda/test_allow_tf32_main.py | 46 ++++++++++++++++++++ readme.md | 5 +++ src/nbla/cuda/cublas.cpp | 11 +++++ src/nbla/cuda/init.cpp.tmpl | 16 +++++++ 7 files changed, 145 insertions(+) create mode 100644 python/test/cuda/test_allow_tf32.py create mode 100644 python/test/cuda/test_allow_tf32_main.py diff --git a/include/nbla/cuda/init.hpp b/include/nbla/cuda/init.hpp index e6f70d5b03..d4734f396b 100644 --- a/include/nbla/cuda/init.hpp +++ b/include/nbla/cuda/init.hpp @@ -1,4 +1,5 @@ // Copyright 2017,2018,2019,2020,2021 Sony Corporation. +// Copyright 2021 Sony Group Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -70,6 +71,11 @@ get_cuda_virtual_caching_allocator_max_available_bytes(const string &device_id); NBLA_CUDA_API vector get_cuda_virtual_memory_used_counts(const string &device_id); +/** + * Check if tf32 is enabled or not. + */ +NBLA_CUDA_API bool is_cuda_tf32_enabled(); + /** Get CUDA array classes. */ NBLA_CUDA_API vector cuda_array_classes(); diff --git a/python/src/nnabla_ext/cuda/init.pyx b/python/src/nnabla_ext/cuda/init.pyx index 5b9d0f9f80..fd91fe1a7d 100644 --- a/python/src/nnabla_ext/cuda/init.pyx +++ b/python/src/nnabla_ext/cuda/init.pyx @@ -1,4 +1,5 @@ # Copyright (c) 2017 Sony Corporation. All Rights Reserved. +# Copyright 2021 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,6 +37,7 @@ cdef extern from "nbla/cuda/init.hpp" namespace "nbla": size_t get_cuda_virtual_caching_allocator_fragmentation_bytes(const string& device_id) except + size_t get_cuda_virtual_caching_allocator_max_available_bytes(const string& device_id) except + vector[int] get_cuda_virtual_memory_used_counts(const string& device_id) except + + bool is_cuda_tf32_enabled() except+ vector[string] cuda_array_classes() except + void _cuda_set_array_classes(const vector[string] & a) except + void cuda_device_synchronize(const string & device) except + @@ -109,6 +111,10 @@ def get_virtual_memory_used_counts(str device_id): """Get # of cuda virtual memory which is currently used.""" return get_cuda_virtual_memory_used_counts(device_id) +def is_tf32_enabled(): + """Check if tf32 is enabled or not.""" + return is_cuda_tf32_enabled() + ############################################################################### # Array preference API # TODO: Move these to C++ diff --git a/python/test/cuda/test_allow_tf32.py b/python/test/cuda/test_allow_tf32.py new file mode 100644 index 0000000000..84bd44f4dc --- /dev/null +++ b/python/test/cuda/test_allow_tf32.py @@ -0,0 +1,55 @@ +# Copyright 2017,2018,2019,2020,2021 Sony Corporation. +# Copyright 2021 Sony Group Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import pytest +import os +import subprocess + + +def teardown_function(function): + if os.getenv('NNABLA_CUDA_ALLOW_TF32'): + del os.environ["NNABLA_CUDA_ALLOW_TF32"] + if os.getenv('NVIDIA_TF32_OVERRIDE'): + del os.environ["NVIDIA_TF32_OVERRIDE"] + + +''' +This test is verify that tf32 is set correctly by the environment variables. +''' + + +@pytest.mark.parametrize("nnabla_tf32", [None, "0", "1"]) +@pytest.mark.parametrize("nvidia_tf32", [None, "0", "1"]) +@pytest.mark.parametrize("context", ['cuda', 'cudnn']) +def test_allow_tf32(nnabla_tf32, nvidia_tf32, context): + + if nnabla_tf32: + os.environ["NNABLA_CUDA_ALLOW_TF32"] = nnabla_tf32 + if nvidia_tf32: + os.environ["NVIDIA_TF32_OVERRIDE"] = nvidia_tf32 + + d = os.path.dirname(os.path.abspath(__file__)) + test_main_path = d + "/test_allow_tf32_main.py" + + if nnabla_tf32 == "1": + result = subprocess.run( + ["python", test_main_path, "--context", context, "--allow-tf32"]) + else: + result = subprocess.run( + ["python", test_main_path, "--context", context]) + + assert result.returncode == 0 diff --git a/python/test/cuda/test_allow_tf32_main.py b/python/test/cuda/test_allow_tf32_main.py new file mode 100644 index 0000000000..8c8cfd0fdb --- /dev/null +++ b/python/test/cuda/test_allow_tf32_main.py @@ -0,0 +1,46 @@ +# Copyright 2017,2018,2019,2020,2021 Sony Corporation. +# Copyright 2021 Sony Group Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import nnabla as nn +import os +import argparse + +''' +This test is called by test_allow_tf32.py to verify that tf32 is set correctly by the environment variables. +In pytest, init_cuda is executed only once throughout the whole tests, and this test case is added as a separate test because the correct behavior could not be confirmed. +''' + + +def main(args): + context = args.context + + from nnabla.ext_utils import get_extension_context + with nn.context_scope(get_extension_context(context)): + import nnabla_ext.cuda.init as cuda_init + allow_tf32 = cuda_init.is_tf32_enabled() + + assert args.allow_tf32 == allow_tf32 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--allow-tf32", action='store_true', default=False) + parser.add_argument("--context", type=str, + default="cuda", choices=["cudnn", "cuda"]) + args = parser.parse_args() + main(args) diff --git a/readme.md b/readme.md index b7f4dd8a77..2df200c618 100644 --- a/readme.md +++ b/readme.md @@ -37,6 +37,11 @@ However, it often consumes much memory due to a big workspace memory required by In some cases it may be desired to restrict the automatic search for CUDNN Convolution algorithms to those that give deterministic (reproducable) results. This can be achived by setting an environment variable `NNABLA_CUDNN_DETERMINISTIC` to some value other than `0`. +### TensorFloat-32 (TF32) + +In NNabla, the environment variable `NNABLA_CUDA_ALLOW_TF32` controls whether TF32 (about TF32, see [a blog post](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) from NVIDIA) is allowed to be used. If `NNABLA_CUDA_ALLOW_TF32` is not set (default) or 0, TF32 is disabled. Otherwise, it is enabled. `NNABLA_CUDA_ALLOW_TF32` always takes priority of `NVIDIA_TF32_OVERRIDE`. `NNABLA_CUDA_ALLOW_TF32` is only evaluated when initializing NNabla CUDA extension. If it is changed within the user program, the behavior is undefined. + + ## FAQ No FAQ so far. diff --git a/src/nbla/cuda/cublas.cpp b/src/nbla/cuda/cublas.cpp index 297449c2d9..73a8f98ddc 100644 --- a/src/nbla/cuda/cublas.cpp +++ b/src/nbla/cuda/cublas.cpp @@ -1,4 +1,5 @@ // Copyright 2017,2018,2019,2020,2021 Sony Corporation. +// Copyright 2021 Sony Group Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -68,11 +69,16 @@ void cublas_gemm(cublasHandle_t handle, cublasOperation_t op_x, cudaDeviceProp prop = cuda_get_current_device_properties(); if (prop.major >= 5) { auto ct = cuda_data_type::type>::type(); +#if CUDA_VERSION < 11000 + // CUBLAS_TENSOR_OP_MATH is deprecated in CUDA 11.0 NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); +#endif NBLA_CUBLAS_CHECK(cublasGemmEx(handle, op_x, op_y, m, n, k, &a, x, dt, lda, y, dt, ldb, &b, z, dt, ldc, ct, infer_gemm_algo_by_type(dt))); +#if CUDA_VERSION < 11000 NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); +#endif } else { NBLA_CUBLAS_CHECK(cublasSgemmEx(handle, op_x, op_y, m, n, k, &a, x, dt, lda, y, dt, ldb, &b, z, dt, ldc)); @@ -299,12 +305,17 @@ void cublas_gemm_strided_batched( float b = beta; cudaDataType_t dt = cuda_data_type::type(); cudaDataType_t ct = cuda_data_type::type(); +#if CUDA_VERSION < 11000 + // CUBLAS_TENSOR_OP_MATH is deprecated in CUDA 11.0 NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); +#endif NBLA_CUBLAS_CHECK(cublasGemmStridedBatchedEx( handle, op_x, op_y, m, n, k, &a, x, dt, lda, stride_a, y, dt, ldb, stride_b, &b, z, dt, ldc, stride_c, batch_count, ct, infer_gemm_algo_by_type(dt))); +#if CUDA_VERSION < 11000 NBLA_CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); +#endif return; } #endif // CUDA_VERSION >= 9010 diff --git a/src/nbla/cuda/init.cpp.tmpl b/src/nbla/cuda/init.cpp.tmpl index 8fea4bdfcb..d9add6fa57 100644 --- a/src/nbla/cuda/init.cpp.tmpl +++ b/src/nbla/cuda/init.cpp.tmpl @@ -1,5 +1,6 @@ // Copyright 2018,2019,2020,2021 Sony Corporation. +// Copyright 2021 Sony Group Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -65,6 +66,15 @@ void init_cuda() { // Init CPU features init_cpu(); + // Set NVIDIA_TF32_OVERRIDE based on NNABLA_CUDA_ALLOW_TF32. + // NNabla does not use TF32 as defalut. + const char* allow_tf32 = std::getenv("NNABLA_CUDA_ALLOW_TF32"); + if (!allow_tf32 || strcmp(allow_tf32, "0") == 0) { + nbla_setenv("NVIDIA_TF32_OVERRIDE", "0"); + } else { + nbla_setenv("NVIDIA_TF32_OVERRIDE", "1"); + } + // Init Cuda Driver API NBLA_CUDA_DRIVER_CHECK(cuInit(0)); @@ -301,6 +311,12 @@ size_t get_cuda_virtual_caching_allocator_max_available_bytes(const string& devi vector get_cuda_virtual_memory_used_counts(const string& device_id) {return {};} #endif // CUDA_VERSION >= 10020 && CUDNN_VERSION >= 8000 +bool is_cuda_tf32_enabled() { + const char* tf32_enable = std::getenv("NVIDIA_TF32_OVERRIDE"); + if (!tf32_enable || strcmp(tf32_enable, "0") == 0) return false; + return true; +} + /** Get CUDA array classes. */ vector cuda_array_classes() {