From a4633b418d4b65432d0eaf6bad0f7084a0d365ac Mon Sep 17 00:00:00 2001 From: Satyanvesh Dittakavi Date: Thu, 3 Jul 2025 09:21:17 +0000 Subject: [PATCH] SWDEV-541185 - use type traits from __hip_internal namespace in hipRTC kernels --- aten/src/ATen/cuda/llvm_complex.cpp | 28 ++++++++++++------------- aten/src/ATen/native/cuda/jit_utils.cpp | 9 ++------ 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/cuda/llvm_complex.cpp b/aten/src/ATen/cuda/llvm_complex.cpp index 9caea9d69f05..a38054cd8e6f 100644 --- a/aten/src/ATen/cuda/llvm_complex.cpp +++ b/aten/src/ATen/cuda/llvm_complex.cpp @@ -514,8 +514,8 @@ operator||(const complex<_Tp>& __x, const complex<_Tp>& __y) // 26.3.7 values: -template ::value, - bool = is_floating_point<_Tp>::value +template ::value, + bool = __hip_internal::is_floating_point<_Tp>::value > struct __libcpp_complex_overload_traits {}; @@ -593,9 +593,9 @@ arg(const complex<_Tp>& __c) template inline -typename enable_if +typename __hip_internal::enable_if < - is_integral<_Tp>::value || is_same<_Tp, double>::value, + __hip_internal::is_integral<_Tp>::value || __hip_internal::is_same<_Tp, double>::value, double >::type arg(_Tp __re) @@ -605,8 +605,8 @@ arg(_Tp __re) template inline -typename enable_if< - is_same<_Tp, float>::value, +typename __hip_internal::enable_if< + __hip_internal::is_same<_Tp, float>::value, float >::type arg(_Tp __re) @@ -716,9 +716,9 @@ proj(const complex<_Tp>& __c) template inline -typename enable_if +typename __hip_internal::enable_if < - is_floating_point<_Tp>::value, + __hip_internal::is_floating_point<_Tp>::value, typename __libcpp_complex_overload_traits<_Tp>::_ComplexType >::type proj(_Tp __re) @@ -730,9 +730,9 @@ proj(_Tp __re) template inline -typename enable_if +typename __hip_internal::enable_if < - is_integral<_Tp>::value, + __hip_internal::is_integral<_Tp>::value, typename __libcpp_complex_overload_traits<_Tp>::_ComplexType >::type proj(_Tp __re) @@ -866,9 +866,9 @@ pow(const complex<_Tp>& __x, const complex<_Up>& __y) template inline -typename enable_if +typename __hip_internal::enable_if < - is_arithmetic<_Up>::value, + __hip_internal::is_arithmetic<_Up>::value, complex::type> >::type pow(const complex<_Tp>& __x, const _Up& __y) @@ -879,9 +879,9 @@ pow(const complex<_Tp>& __x, const _Up& __y) template inline -typename enable_if +typename __hip_internal::enable_if < - is_arithmetic<_Tp>::value, + __hip_internal::is_arithmetic<_Tp>::value, complex::type> >::type pow(const _Tp& __x, const complex<_Up>& __y) diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 0d49ec9c187c..8abde9db2fa2 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -68,13 +68,8 @@ template struct remove_cv {typedef typename remove_volatile::type>::type type;}; template using remove_cv_t = typename remove_cv<_Tp>::type; -template struct __libcpp_is_floating_point : public false_type {}; -template <> struct __libcpp_is_floating_point : public true_type {}; -template <> struct __libcpp_is_floating_point : public true_type {}; -template <> struct __libcpp_is_floating_point : public true_type {}; - template -inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value; +inline constexpr bool is_arithmetic_v = __hip_internal::is_arithmetic<_Tp>::value; template struct __numeric_type @@ -92,7 +87,7 @@ struct __numeric_type static long double __test(long double); typedef decltype(__test(declval<_Tp>())) type; - static const bool value = !is_same::value; + static const bool value = !__hip_internal::is_same::value; }; template <>