Skip to content

Use type traits from __hip_internal namespace in HIPRTC kernels #2312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: rocm7.0_internal_testing
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions aten/src/ATen/cuda/llvm_complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ operator||(const complex<_Tp>& __x, const complex<_Tp>& __y)

// 26.3.7 values:

template <class _Tp, bool = is_integral<_Tp>::value,
bool = is_floating_point<_Tp>::value
template <class _Tp, bool = __hip_internal::is_integral<_Tp>::value,
bool = __hip_internal::is_floating_point<_Tp>::value
>
struct __libcpp_complex_overload_traits {};

Expand Down Expand Up @@ -593,9 +593,9 @@ arg(const complex<_Tp>& __c)

template<class _Tp>
inline
typename enable_if
typename __hip_internal::enable_if
<
is_integral<_Tp>::value || is_same<_Tp, double>::value,
__hip_internal::is_integral<_Tp>::value || __hip_internal::is_same<_Tp, double>::value,
double
>::type
arg(_Tp __re)
Expand All @@ -605,8 +605,8 @@ arg(_Tp __re)

template <class _Tp>
inline
typename enable_if<
is_same<_Tp, float>::value,
typename __hip_internal::enable_if<
__hip_internal::is_same<_Tp, float>::value,
float
>::type
arg(_Tp __re)
Expand Down Expand Up @@ -716,9 +716,9 @@ proj(const complex<_Tp>& __c)

template <class _Tp>
inline
typename enable_if
typename __hip_internal::enable_if
<
is_floating_point<_Tp>::value,
__hip_internal::is_floating_point<_Tp>::value,
typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
>::type
proj(_Tp __re)
Expand All @@ -730,9 +730,9 @@ proj(_Tp __re)

template <class _Tp>
inline
typename enable_if
typename __hip_internal::enable_if
<
is_integral<_Tp>::value,
__hip_internal::is_integral<_Tp>::value,
typename __libcpp_complex_overload_traits<_Tp>::_ComplexType
>::type
proj(_Tp __re)
Expand Down Expand Up @@ -866,9 +866,9 @@ pow(const complex<_Tp>& __x, const complex<_Up>& __y)

template<class _Tp, class _Up>
inline
typename enable_if
typename __hip_internal::enable_if
<
is_arithmetic<_Up>::value,
__hip_internal::is_arithmetic<_Up>::value,
complex<typename __promote<_Tp, _Up>::type>
>::type
pow(const complex<_Tp>& __x, const _Up& __y)
Expand All @@ -879,9 +879,9 @@ pow(const complex<_Tp>& __x, const _Up& __y)

template<class _Tp, class _Up>
inline
typename enable_if
typename __hip_internal::enable_if
<
is_arithmetic<_Tp>::value,
__hip_internal::is_arithmetic<_Tp>::value,
complex<typename __promote<_Tp, _Up>::type>
>::type
pow(const _Tp& __x, const complex<_Up>& __y)
Expand Down
9 changes: 2 additions & 7 deletions aten/src/ATen/native/cuda/jit_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,8 @@ template <class _Tp> struct remove_cv
{typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;

template <class _Tp> struct __libcpp_is_floating_point : public false_type {};
template <> struct __libcpp_is_floating_point<float> : public true_type {};
template <> struct __libcpp_is_floating_point<double> : public true_type {};
template <> struct __libcpp_is_floating_point<long double> : public true_type {};

template <class _Tp>
inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;
inline constexpr bool is_arithmetic_v = __hip_internal::is_arithmetic<_Tp>::value;

template <class _Tp>
struct __numeric_type
Expand All @@ -92,7 +87,7 @@ struct __numeric_type
static long double __test(long double);

typedef decltype(__test(declval<_Tp>())) type;
static const bool value = !is_same<type, void>::value;
static const bool value = !__hip_internal::is_same<type, void>::value;
};

template <>
Expand Down