Skip to content

Commit

Permalink
Try to work around issue with NVHPC in conjunction with older CTK versi…
Browse files Browse the repository at this point in the history
…ons (#2889)

NVHPC can consume older CTK headers for stdpar, so we need to try and avoid using those
  • Loading branch information
miscco authored Nov 21, 2024
1 parent a39a8a7 commit 1b8151c
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions cub/cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,15 @@ struct SimdMin<__half>

// Computes the per-lane minimum of two packed half-precision pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __half2 holding min(a.x, b.x) in the first lane and min(a.y, b.y) in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __half2 operator()(__half2 a, __half2 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: widen each lane to float,
  // take the minimum there, and pack both results with a single conversion call.
  const float lane_x = ::cuda::minimum<>{}(__half2float(a.x), __half2float(b.x));
  const float lane_y = ::cuda::minimum<>{}(__half2float(a.y), __half2float(b.y));
  return __floats2half2_rn(lane_x, lane_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets go through float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (return __hmin2(a, b);),
               (const float lane_x = ::cuda::minimum<>{}(__half2float(a.x), __half2float(b.x));
                const float lane_y = ::cuda::minimum<>{}(__half2float(a.y), __half2float(b.y));
                return __halves2half2(__float2half(lane_x), __float2half(lane_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -470,11 +475,16 @@ struct SimdMin<__nv_bfloat16>

// Computes the per-lane minimum of two packed bfloat16 pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __nv_bfloat162 holding min(a.x, b.x) in the first lane and min(a.y, b.y) in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __nv_bfloat162 operator()(__nv_bfloat162 a, __nv_bfloat162 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: widen each lane to float,
  // take the minimum there, and pack both results with a single conversion call.
  const float lane_x = ::cuda::minimum<>{}(__bfloat162float(a.x), __bfloat162float(b.x));
  const float lane_y = ::cuda::minimum<>{}(__bfloat162float(a.y), __bfloat162float(b.y));
  return __floats2bfloat162_rn(lane_x, lane_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets go through float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (return __hmin2(a, b);),
               (const float lane_x = ::cuda::minimum<>{}(__bfloat162float(a.x), __bfloat162float(b.x));
                const float lane_y = ::cuda::minimum<>{}(__bfloat162float(a.y), __bfloat162float(b.y));
                return cub::internal::halves2bfloat162(__float2bfloat16(lane_x), __float2bfloat16(lane_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand Down Expand Up @@ -521,10 +531,15 @@ struct SimdMax<__half>

// Computes the per-lane maximum of two packed half-precision pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __half2 holding max(a.x, b.x) in the first lane and max(a.y, b.y) in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __half2 operator()(__half2 a, __half2 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: widen each lane to float,
  // take the maximum there, and pack both results with a single conversion call.
  const float lane_x = ::cuda::maximum<>{}(__half2float(a.x), __half2float(b.x));
  const float lane_y = ::cuda::maximum<>{}(__half2float(a.y), __half2float(b.y));
  return __floats2half2_rn(lane_x, lane_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets go through float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (return __hmax2(a, b);),
               (const float lane_x = ::cuda::maximum<>{}(__half2float(a.x), __half2float(b.x));
                const float lane_y = ::cuda::maximum<>{}(__half2float(a.y), __half2float(b.y));
                return __halves2half2(__float2half(lane_x), __float2half(lane_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -539,11 +554,16 @@ struct SimdMax<__nv_bfloat16>

// Computes the per-lane maximum of two packed bfloat16 pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __nv_bfloat162 holding max(a.x, b.x) in the first lane and max(a.y, b.y) in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __nv_bfloat162 operator()(__nv_bfloat162 a, __nv_bfloat162 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: widen each lane to float,
  // take the maximum there, and pack both results with a single conversion call.
  const float lane_x = ::cuda::maximum<>{}(__bfloat162float(a.x), __bfloat162float(b.x));
  const float lane_y = ::cuda::maximum<>{}(__bfloat162float(a.y), __bfloat162float(b.y));
  return __floats2bfloat162_rn(lane_x, lane_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets go through float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (return __hmax2(a, b);),
               (const float lane_x = ::cuda::maximum<>{}(__bfloat162float(a.x), __bfloat162float(b.x));
                const float lane_y = ::cuda::maximum<>{}(__bfloat162float(a.y), __bfloat162float(b.y));
                return cub::internal::halves2bfloat162(__float2bfloat16(lane_x), __float2bfloat16(lane_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -566,10 +586,14 @@ struct SimdSum<__half>

// Computes the per-lane sum of two packed half-precision pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __half2 holding a.x + b.x in the first lane and a.y + b.y in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __half2 operator()(__half2 a, __half2 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: add in float per lane,
  // then pack both sums with a single conversion call.
  const float sum_x = __half2float(a.x) + __half2float(b.x);
  const float sum_y = __half2float(a.y) + __half2float(b.y);
  return __floats2half2_rn(sum_x, sum_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 5.3 use the packed intrinsic; other targets add in float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_53,
               (return __hadd2(a, b);),
               (const float sum_x = __half2float(a.x) + __half2float(b.x);
                const float sum_y = __half2float(a.y) + __half2float(b.y);
                return __halves2half2(__float2half(sum_x), __float2half(sum_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -584,11 +608,16 @@ struct SimdSum<__nv_bfloat16>

// Computes the per-lane sum of two packed bfloat16 pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __nv_bfloat162 holding a.x + b.x in the first lane and a.y + b.y in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __nv_bfloat162 operator()(__nv_bfloat162 a, __nv_bfloat162 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: add in float per lane,
  // then pack both sums with a single conversion call.
  const float sum_x = __bfloat162float(a.x) + __bfloat162float(b.x);
  const float sum_y = __bfloat162float(a.y) + __bfloat162float(b.y);
  return __floats2bfloat162_rn(sum_x, sum_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets add in float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (return __hadd2(a, b);),
               (const float sum_x = __bfloat162float(a.x) + __bfloat162float(b.x);
                const float sum_y = __bfloat162float(a.y) + __bfloat162float(b.y);
                return cub::internal::halves2bfloat162(__float2bfloat16(sum_x), __float2bfloat16(sum_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -611,10 +640,14 @@ struct SimdMul<__half>

// Computes the per-lane product of two packed half-precision pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __half2 holding a.x * b.x in the first lane and a.y * b.y in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __half2 operator()(__half2 a, __half2 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: multiply in float per lane,
  // then pack both products with a single conversion call.
  const float prod_x = __half2float(a.x) * __half2float(b.x);
  const float prod_y = __half2float(a.y) * __half2float(b.y);
  return __floats2half2_rn(prod_x, prod_y);
#  else // ^^^ NVHPC with CUDACC below 12.0 ^^^ / vvv every other configuration vvv
  // Targets providing SM 5.3 use the packed intrinsic; other targets multiply in float per lane.
  NV_IF_TARGET(NV_PROVIDES_SM_53,
               (return __hmul2(a, b);),
               (const float prod_x = __half2float(a.x) * __half2float(b.x);
                const float prod_y = __half2float(a.y) * __half2float(b.y);
                return __halves2half2(__float2half(prod_x), __float2half(prod_y));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand All @@ -629,10 +662,15 @@ struct SimdMul<__nv_bfloat16>

// Computes the per-lane product of two packed bfloat16 pairs.
//
// a, b   : packed operands; lanes are accessed as `.x` and `.y`.
// returns: a __nv_bfloat162 holding a.x * b.x in the first lane and a.y * b.y in the second.
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE __nv_bfloat162 operator()(__nv_bfloat162 a, __nv_bfloat162 b) const
{
#  if _CCCL_CUDACC_BELOW(12, 0) && defined(_CCCL_CUDA_COMPILER_NVHPC)
  // NVHPC paired with a pre-12.0 CUDA compiler version: multiply in float per lane,
  // then pack both products with a single conversion call.
  return __floats2bfloat162_rn(
    __bfloat162float(a.x) * __bfloat162float(b.x), __bfloat162float(a.y) * __bfloat162float(b.y));
#  else // ^^^ _CCCL_CUDACC_BELOW(12, 0) && _CCCL_CUDA_COMPILER_NVHPC ^^^ / vvv otherwise vvv
  // Targets providing SM 8.0 use the packed intrinsic; other targets multiply in float per lane.
  // The packing helper is spelled with its full qualification, matching the other
  // bfloat16 functors in this file, so lookup does not depend on the enclosing namespace.
  NV_IF_TARGET(
    NV_PROVIDES_SM_80,
    (return __hmul2(a, b);),
    (return cub::internal::halves2bfloat162(__float2bfloat16(__bfloat162float(a.x) * __bfloat162float(b.x)),
                                            __float2bfloat16(__bfloat162float(a.y) * __bfloat162float(b.y)));));
#  endif // !_CCCL_CUDACC_BELOW(12, 0) || !_CCCL_CUDA_COMPILER_NVHPC
}
};

Expand Down

0 comments on commit 1b8151c

Please sign in to comment.