// Patch summary (three files, one hunk each):
//   * library/src/include/utility.hpp
//       Declares `int rocblas_internal_get_arch(rocblas_handle handle);`
//       next to the existing rocblas_internal_tensile_supports_ldc_ne_ldd
//       declaration.
//   * library/src/rocblas_auxiliary.cpp
//       Defines rocblas_internal_get_arch as a forwarder that returns
//       handle->getArch(). The handle is dereferenced without a null check,
//       as in the neighbouring rocblas_internal_tensile_supports_ldc_ne_ldd.
//   * library/src/tensile_host.cpp
//       In useHipBLASLt(), inside the BUILD_WITH_HIPBLASLT branch, adds an
//       early `return false` when sizeof(Ti) >= 4 and
//       rocblas_internal_get_arch(prob.handle) == 950. It runs before the
//       existing prob.handle->tryHipBLASLt(batched) call. The in-code TODO
//       marks this check for removal after tuning.
//
// NOTE(review): this patch was found collapsed onto a single line. Line breaks
// were reconstructed from the hunk line counts and indentation is a
// conventional guess. Some context text looks truncated by extraction
// ("template", "struct XnackMode>") and is left exactly as found. Confirm
// against the original patch before applying.
diff --git a/library/src/include/utility.hpp b/library/src/include/utility.hpp
index ff9351482..2cea9b505 100644
--- a/library/src/include/utility.hpp
+++ b/library/src/include/utility.hpp
@@ -789,6 +789,9 @@ constexpr double rocblas_internal_value_category(const T& beta)
     return beta == T(0) ? 0.0 : beta == T(1) ? 1.0 : beta == T(-1) ? -1.0 : 2.0;
 }
 
+// Internal use
+int rocblas_internal_get_arch(rocblas_handle handle);
+
 // Internal use, whether Tensile supports ldc != ldd
 // We assume true if the value is greater than or equal to 906
 bool rocblas_internal_tensile_supports_ldc_ne_ldd(rocblas_handle handle);
diff --git a/library/src/rocblas_auxiliary.cpp b/library/src/rocblas_auxiliary.cpp
index 0847b3a6b..58720f964 100644
--- a/library/src/rocblas_auxiliary.cpp
+++ b/library/src/rocblas_auxiliary.cpp
@@ -897,6 +897,11 @@ struct XnackMode>
     }
 };
 
+int rocblas_internal_get_arch(rocblas_handle handle)
+{
+    return handle->getArch();
+}
+
 bool rocblas_internal_tensile_supports_ldc_ne_ldd(rocblas_handle handle)
 {
     return handle->getArch() >= 906;
diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp
index 89526e050..946b157ee 100644
--- a/library/src/tensile_host.cpp
+++ b/library/src/tensile_host.cpp
@@ -1081,6 +1081,15 @@ template
 bool useHipBLASLt(const RocblasContractionProblem& prob)
 {
 #ifdef BUILD_WITH_HIPBLASLT
+    if constexpr(sizeof(Ti) >= 4)
+    {
+        // TODO remove after tuning
+        if(rocblas_internal_get_arch(prob.handle) == 950)
+        {
+            return false;
+        }
+    }
+
     bool batched = prob.batch_A != nullptr;
     return prob.handle->tryHipBLASLt(batched);
 #else