diff --git a/Dockerfile b/Dockerfile index a9c0b301..8e6cbddf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5 -ENV BMP_AVX512=1 +ENV BMT_AVX512=1 ADD other_requirements.txt other_requirements.txt RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple ADD . . diff --git a/csrc/cuda/adam.cu b/csrc/cuda/adam.cu index 44ec0057..4b3df9c7 100644 --- a/csrc/cuda/adam.cu +++ b/csrc/cuda/adam.cu @@ -50,6 +50,7 @@ void adam_launcher( float bias_correction2 ) { int32_t n = param_fp32.numel(); + if (n <= 0) return; auto g_ptr = reinterpret_cast(g_fp16.data_ptr()); auto m_ptr = reinterpret_cast(m_fp16.data_ptr()); auto v_ptr = v_fp32.data_ptr(); diff --git a/csrc/cuda/has_inf_nan.cu b/csrc/cuda/has_inf_nan.cu index 0bbde439..c8bcca1a 100644 --- a/csrc/cuda/has_inf_nan.cu +++ b/csrc/cuda/has_inf_nan.cu @@ -73,6 +73,7 @@ void has_nan_inf_launcher( torch::Tensor out ) { int n = g_fp16.numel(); + if (n <= 0) return; auto g_ptr = reinterpret_cast(g_fp16.data_ptr()); auto mid_ptr = mid.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream();