diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 18731e810e..03c531c1ad 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 87812369bd..5167097b6d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index c3e6ef7d5d..abf7ef3905 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 93034a8b70..2582ea8a11 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index e7c1d6f0be..57e2feb084 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index d1e1a51afd..2d9c794faa 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -70,3 +70,5 @@ example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpresh example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) + +add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp new file mode 100644 index 0000000000..54abab2f60 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct AddAdd +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c + d0 + d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3 + // 
clang-format off + //#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S| | | + < A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + 
ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << 
e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp index 580f38a79f..086ea45d10 100644 --- a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -184,7 +184,6 @@ int main(int argc, char* argv[]) b0_device_buf.ToDevice(b0_k_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -220,11 +219,12 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50}); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -233,8 +233,6 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) { Tensor c_m_n({M, N}); diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt new file mode 100644 index 0000000000..af091d32e4 --- /dev/null +++ b/example/68_gemm_add/CMakeLists.txt @@ -0,0 +1,22 @@ +add_custom_target(example_gemm_add_xdl) + +add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16) + + +add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16) + +add_custom_target(example_gemm_add_wmma) + +add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) 
+add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_bf16) + +add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) +add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_fp16) + + + + + + diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp new file mode 100644 index 0000000000..38e77a160f --- /dev/null +++ b/example/68_gemm_add/common.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row_Tuple = ck::Tuple; +using F16_Tuple = ck::Tuple; +using BF16_Tuple = ck::Tuple; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; 
+ + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; +}; +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD = std::stoi(argv[9]); + problem_size.StrideE = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD," + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp new file mode 100644 index 0000000000..30f0aa9153 --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using DsDataType = BF16_Tuple; +using EDataType = BF16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + BF16, + BF16, + BF16_Tuple, + BF16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_example_wmma.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp new file mode 100644 index 0000000000..caf245bf76 --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_fp16.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = F16_Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + F16, + F16, + F16_Tuple, + F16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_example_wmma.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_xdl_bf16.cpp b/example/68_gemm_add/gemm_add_xdl_bf16.cpp new file mode 100644 index 0000000000..284e424c14 --- /dev/null +++ b/example/68_gemm_add/gemm_add_xdl_bf16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gemm_add_example_xdl.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_xdl_fp16.cpp b/example/68_gemm_add/gemm_add_xdl_fp16.cpp new file mode 100644 index 0000000000..4ba10e9d3b --- /dev/null +++ b/example/68_gemm_add/gemm_add_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gemm_add_example_xdl.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/run_gemm_add_example_wmma.inc b/example/68_gemm_add/run_gemm_add_example_wmma.inc new file mode 100644 index 0000000000..7a6c8ea56d --- /dev/null +++ b/example/68_gemm_add/run_gemm_add_example_wmma.inc @@ -0,0 +1,145 @@ +#pragma once + +bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return 
HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + 
e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + bool pass = true; + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + + return pass; +} + +bool run_gemm_add_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config); +} diff --git a/example/68_gemm_add/run_gemm_add_example_xdl.inc b/example/68_gemm_add/run_gemm_add_example_xdl.inc new file mode 100644 index 0000000000..97c0765c27 --- 
/dev/null +++ b/example/68_gemm_add/run_gemm_add_example_xdl.inc @@ -0,0 +1,144 @@ +#pragma once + +bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + 
b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + bool pass = true; + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + 
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + + return pass; +} + +bool run_gemm_add_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config); +} diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt new file mode 100644 index 0000000000..9ab3ef5a45 --- /dev/null +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -0,0 +1,15 @@ +add_custom_target(example_gemm_add_relu_xdl) + +add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_fp16) + +add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_bf16) + +add_custom_target(example_gemm_add_relu_wmma) + +add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp) +add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_bf16) + +add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp) +add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_fp16) diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp new file mode 100644 index 0000000000..311cbb2dfb --- /dev/null +++ b/example/69_gemm_add_relu/common.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row_Tuple = ck::Tuple; +using F16_Tuple = ck::Tuple; +using BF16_Tuple = ck::Tuple; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; +}; +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + 
config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD = std::stoi(argv[9]); + problem_size.StrideE = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD," + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp new file mode 100644 index 0000000000..5c4116cc44 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using DsDataType = BF16_Tuple; +using EDataType = BF16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + BF16, + BF16, + BF16_Tuple, + BF16, + F32, + F32, + PassThrough, + PassThrough, + AddRelu, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_wmma.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp new file mode 100644 index 0000000000..07f5197d21 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = F16_Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + F16, + F16, + F16_Tuple, + F16, + F32, + F32, + PassThrough, + PassThrough, + AddRelu, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_wmma.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp new file mode 100644 index 0000000000..b5a84cd828 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gemm_add_relu_example_xdl.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp new file mode 100644 index 0000000000..9e91641ba4 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gemm_add_relu_example_xdl.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc new file mode 100644 index 0000000000..27bd4de48d --- /dev/null +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc @@ -0,0 +1,146 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return 
HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + 
b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + bool pass = true; + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + + return pass; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_gemm_add_relu(problem_size, config); +} diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc 
b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc new file mode 100644 index 0000000000..e2d45fca43 --- /dev/null +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc @@ -0,0 +1,145 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem 
e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + bool pass = true; + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 
0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + + return pass; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_gemm_add_relu(problem_size, config); +} diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100755 new mode 100644 index 48306e35fe..c0e4dc3d30 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -525,8 +525,8 @@ bool run(const ck_tile::ArgParser& arg_parser) flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(ODataType) * real_seqlen_q * hdim_v); num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + sizeof(VDataType) * hdim_v * real_seqlen_k); } diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp old mode 100755 new mode 100644 index b96482f535..af38ff0214 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 6ebdbc5054..3dff1b28c6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -149,50 +149,105 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator #endif }; +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is +/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter +/// KBatch which is explicitly passed as 1 by this wrapper. template -struct DeviceMoEGemmMXBPreShuffle : public BaseOperator +struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD { + using DeviceOp = DeviceGemmMultipleDSplitK; + static constexpr index_t NumDTensor = DsDataType::Size(); -#ifndef CK_CODE_GEN_RTC - virtual std::unique_ptr +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_a_scale, const void* p_b, - const void* p_b_scale, std::array p_ds, void* p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, - ck::index_t StrideAScale, ck::index_t StrideB, - ck::index_t StrideBScale, std::array StrideDs, ck::index_t StrideE, - ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) = 0; - - virtual std::unique_ptr 
MakeInvokerPointer() = 0; - - virtual int GetPreShuffleParameters() = 0; -#endif + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, // KBatch + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index 5d983afb9b..c00078186f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -40,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if(defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; + using c_data_type = remove_cvref_t>; if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && (std::is_same_v || std::is_same_v))) @@ -62,14 +62,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_ds_grid, + 
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, p_shared, - karg); + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); #if defined(__gfx11__) } #endif @@ -272,11 +276,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm, // DsLayout CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, + Tuple<>, // DsDataType CDataType, AElementwiseOperation, BElementwiseOperation, @@ -311,7 +317,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -336,17 +342,25 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm{}, // p_ds_grid_ p_c_grid_, M_, N_, K_, StrideA_, StrideB_, + std::array{}, // StrideDs_ StrideC_, k_batch_, + a_element_op_, + b_element_op_, + cde_element_op_, is_reduce_), Batch(Batch_), compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} @@ -443,7 +457,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors. 
+/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise +// operations that could be applied on each tensor respectively. The CDE_op is an +// elementwise operation applied to the C and all D tensors. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. 
+/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. 
+/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. 
+/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemmMultipleD_Wmma_CShuffleV3 + : public DeviceGemmMultipleDSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using 
Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + 
{BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0]; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + str << std::string(DLayout::name)[0]; + }); + str << std::string(ELayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< { - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, + Tuple<>, // DsLayout CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, + Tuple<>, // DsDataType CDataType, AElementwiseOperation, BElementwiseOperation, @@ -220,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -230,21 +231,24 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2; + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; // Invoker using Invoker = typename DeviceGemmCommon::Invoker; @@ -275,11 +279,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2{}, // p_ds_grid_ + p_c, + M, + N, + K, + StrideA, + StrideB, + std::array{}, // StrideDs_ + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -295,20 +313,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2(static_cast(p_a), 
static_cast(p_b), + std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, StrideA, StrideB, + std::array{}, // StrideDs_ StrideC, - KBatch); + KBatch, + a_element_op, + b_element_op, + c_element_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp index 1a68b35f1f..a9d5c666a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -89,11 +89,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, + Tuple<>, // DsDataType CDataType, AElementwiseOperation, BElementwiseOperation, @@ -130,7 +132,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -140,21 +142,24 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale; + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; // Invoker using Invoker = typename DeviceGemmCommon::Invoker; @@ -188,23 +193,25 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{}, // p_ds_grid_ p_c, M, N, K, StrideA, StrideB, + std::array{}, // StrideDs_ StrideC, StrideScaleB, p_b_scale, KBatch, a_element_op, b_element_op, - c_element_op}; + cde_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -228,12 +235,14 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale(static_cast(p_a), static_cast(p_b), + std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, StrideA, StrideB, + std::array{}, // StrideDs_ StrideC, StrideScaleB, 
static_cast(p_b_scale), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 24b96a1e60..55aa7b59ee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -24,7 +24,8 @@ namespace device { template rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + std::array size_ds_buffers; + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + size_ds_buffers[i] = + ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + size_ds_buffers); rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -106,9 +122,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common rotating_mem.Next(); // clear c mem if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, 0, - arg_.M * arg_.N * sizeof(CDataType), + arg_.M * arg_.N * sizeof(EDataType), stream_config.stream_id_)); }; @@ -124,9 +140,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common else { if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, 0, - arg.M * arg.N * sizeof(CDataType), + arg.M * arg.N * sizeof(EDataType), stream_config.stream_id_)); ave_time = launch_and_time_kernel( @@ -149,6 +165,16 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } }(); + // ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is + // currently implemented in such a way that all SrcScalarPerVectors must be the 
same, so + // if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the + // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot + // be odd. + constexpr bool AtomicsImplementationExists = + !(std::is_same_v || + std::is_same_v) || + (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); + if(has_main_k_block_loop) { // Tail number always full @@ -157,12 +183,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common { if(arg.KBatch > 1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if constexpr(AtomicsImplementationExists) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } } else { @@ -186,12 +215,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common { if(arg.KBatch > 1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if constexpr(AtomicsImplementationExists) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } } else { @@ -229,8 +261,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common return false; } - if constexpr(std::is_same_v || - std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { if(arg.KBatch > 1 && ck::is_gfx11_supported()) { diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index d86f01e255..61d249fc93 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -47,7 +47,7 @@ struct Add __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const { - y = type_convert(x0) + x1; + y = x0 + type_convert(x1); }; template <> diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 9a8d09e5e4..bd2a8b04bc 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" @@ -22,9 +22,10 @@ namespace ck { /// /// @par Overview /// This GEMM kernel is carrying out following mathematical equation: -/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) -/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are -/// elementwise operations that could be applied on each tensor respectively. +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A_op/B_op are elementwise +/// operations that could be applied on each tensor respectively. The CDE_op is an +/// elementwise operation applied to the C and all D tensors. /// The \"universal\" gemm comes with multiple pipelines optimized for different usage /// scenarios. That's why it's called \"universal\". It's universal through it's design /// and versatilty. @@ -36,18 +37,20 @@ namespace ck { /// /// @tparam ALayout A tensor data layout. /// @tparam BLayout B tensor data layout. -/// @tparam CLayout C tensor data layout. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. /// @tparam ADataType A tensor data type. /// @tparam BDataType B tensor data type. 
/// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into /// LDS memory during \"CShuffle\" data layout optimization. -/// @tparam CDataType C tensor data type. -/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. -/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. -/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor -/// (after GEMM). +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. /// @tparam GemmSpec Determines used "padding" version. /// @tparam BlockSize The number of threads within workgroup. /// @tparam MPerBlock The input/output data tile size in the M dimension. @@ -105,11 +108,12 @@ namespace ck { /// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions /// results to process per wave per iteration of CShuffle /// in N dimension. -/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial /// thread distribution used for storing data into output /// tensor across output data layout dimensions. -/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. -/// Used when storing data to output tensor. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. 
+/// Used when loading data from D tensors and storing data +/// to output tensor. /// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or /// intrawave). /// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. @@ -123,15 +127,17 @@ namespace ck { /// in global memory (pre-shuffled). template ; using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumDTensor; + using typename Base::DsGridPointer; + struct Problem { __host__ Problem(index_t M_, @@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t KBatch_) : M{M_}, N{N_}, K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, - StrideC{StrideC_}, + StrideDs{StrideDs_}, + StrideE{StrideE_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC - << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 - << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + if constexpr(NumDTensor > 0) + { + std::cout << "SDs: { "; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? 
", " : ""); + }); + std::cout << " }, "; + } + std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K; index_t StrideA; index_t StrideB; - index_t StrideC; + std::array StrideDs; + index_t StrideE; index_t KBatch; index_t MPadded; index_t NPadded; @@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3 { __host__ Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, - CDataType* p_c_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, - p_c_grid{p_c_grid_}, + p_ds_grid{}, + p_e_grid{p_e_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); } __host__ __device__ inline bool IsReduceAdd() const @@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3 const ADataType* p_a_grid; const BDataType* p_b_grid; - CDataType* p_c_grid; + DsGridPointer p_ds_grid; + EDataType* p_e_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CDEElementwiseOperation cde_element_op; + + // TODO: it can be used with SplitK+reduction but 
currently only used with SplitK+atomicAdd bool is_reduce; }; struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + b_k_split_offset = k_id * karg.KRead / BPackedSize; } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + b_k_split_offset = k_id * k0_offset / BPackedSize; } } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } @@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(karg.IsReduceAdd()) { - c_reduce_offset = blockIdx.z * karg.M * karg.N; + c_reduce_offset = k_id * karg.M * karg.N; } else { @@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3 __device__ static index_t GetKBlockPerScale() { return 1; } template + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + TailNumber TailNum> __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = 
MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, problem.MBlock, problem.NBlock); // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } @@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3 Base::template Run(p_a_grid, p_b_grid, - p_c_grid, + p_ds_grid, + p_e_grid, p_shared, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, block_m_id, block_n_id, num_k_block_per_scale, @@ -528,17 
+590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3 // Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. template + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + TailNumber TailNum> __device__ static void - Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) { - Run( + Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 37ffbf1c51..29c5ae31cd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -20,15 +20,17 @@ namespace ck { template @@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, - CLayout, + DsLayout, + ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, - CDataType, + DsDataType, + EDataType, AElementwiseOperation, BElementwiseOperation, - CElementwiseOperation, + CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, @@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 
CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, - CLayout, + DsLayout, + ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, - CDataType, + DsDataType, + EDataType, AElementwiseOperation, BElementwiseOperation, - CElementwiseOperation, + CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, @@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::CalculateNPadded; using Base::MakeAGridDescriptor_AK0_M_AK1; using Base::MakeBGridDescriptor_BK0_N_BK1; - using Base::MakeCGridDescriptor_M_N; + using Base::MakeDEGridDescriptor_M_N; + using Base::MakeDsGridDescriptor_M_N; + using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; - using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumDTensor; + using typename Base::DsGridPointer; + struct Problem { __host__ Problem(index_t M_, @@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t StrideScaleB_, index_t KBatch_) : 
M{M_}, @@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, - StrideC{StrideC_}, + StrideDs{StrideDs_}, + StrideE{StrideE_}, StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, @@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC - << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded - << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + if constexpr(NumDTensor > 0) + { + std::cout << "SDs: { "; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? 
", " : ""); + }); + std::cout << " }, "; + } + std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead + << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 + << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" + << std::endl; } index_t M; @@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t K; index_t StrideA; index_t StrideB; - index_t StrideC; + std::array StrideDs; + index_t StrideE; index_t StrideScaleB; index_t KBatch; index_t MPadded; @@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale { __host__ Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, - CDataType* p_c_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t StrideScaleB_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, + CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, - p_c_grid{p_c_grid_}, + p_ds_grid{}, + p_e_grid{p_e_grid_}, p_b_scale_grid{p_b_scale_grid_}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, - c_element_op{c_element_op_}, + cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); } __host__ __device__ inline bool IsReduceAdd() const @@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const ADataType* p_a_grid; const BDataType* p_b_grid; - CDataType* 
p_c_grid; + DsGridPointer p_ds_grid; + EDataType* p_e_grid; const BScaleType* p_b_scale_grid; const AElementwiseOperation a_element_op; const BElementwiseOperation b_element_op; - const CElementwiseOperation c_element_op; + const CDEElementwiseOperation cde_element_op; bool is_reduce; }; struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + b_k_split_offset = k_id * karg.KRead / BPackedSize; } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + b_k_split_offset = k_id * k0_offset / BPackedSize; } } // Calculate B scale offset if constexpr(is_same_v) { - scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB; + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; } else if constexpr(is_same_v) { - scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } @@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale if(karg.IsReduceAdd()) { - c_reduce_offset = blockIdx.z * karg.M * karg.N; + c_reduce_offset = k_id * karg.M * karg.N; } else { @@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale } template + InMemoryDataOperationEnum EGlobalMemoryDataOperation, 
+ TailNumber TailNum> __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, const BScaleType* p_b_scale_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, problem.MBlock, problem.NBlock); // B Scale grid const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( @@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + 
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } @@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale Base::template Run(p_a_grid, p_b_grid, - p_c_grid, + p_ds_grid, + p_e_grid, p_shared, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, block_m_id, block_n_id, num_k_block_per_scale, @@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale // NOTE: Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. template + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + TailNumber TailNum> __device__ static void - Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) { - Run( + Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, p_shared, - karg); + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index c60dba3b48..f779909e87 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include 
"ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -19,7 +19,7 @@ namespace ck { template __global__ void @@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if(defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) { #endif __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg); #if defined(__gfx11__) @@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) template {}; static constexpr auto I7 = Number<7>{}; + static constexpr auto EShuffleBlockTransferScalarPerVector = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> static constexpr auto AK0Number = Number{}; static constexpr auto BK0Number = Number{}; @@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); } + template __host__ __device__ static auto - 
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) { const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE)); } }(); @@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base #endif } + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDEGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base NRepeat, KPack>())>; - template - __host__ __device__ static constexpr 
auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + template + __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, + const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + de_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), make_unmerge_transform(make_tuple(NBlock, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - return c_grid_desc_mblock_mperblock_nblock_nperblock; + return de_grid_desc_mblock_mperblock_nblock_nperblock; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } - if constexpr(is_same::value) + if constexpr(is_same::value) { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.N % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } } else { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.M % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } } - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) { - if(!karg.IsReduceAdd()) + if(karg.IsAtomicAdd() && karg.KBatch > 1) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this " + << "destination type (EDataType) " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, void* p_shared, const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, + const 
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, const index_t& block_m_id, const index_t& block_n_id, const index_t& num_k_block_per_scale, @@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = @@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base m_thread_data_on_block_idx[I3]), ck::tensor_operation::element_wise::PassThrough{}}; - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return 
ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); }, + Number{})); + + // blockwise copy which loads C from LDS, D from global, applies elementwise + // operation and stores result E to global + auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOps, Sequence<1, CShuffleMRepeatPerShuffle * MWave * MPerWmma, 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - 
c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // DstDimAccessOrder, + 3, // SrcVectorDim, + 3, // DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors + EShuffleBlockTransferScalarPerVector, // DstScalarPerVector + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + cde_element_op}; // space filling curve for local reg & global memory // space filling curve for threadwise C in VGPR @@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base MAccVgprs>>{}; // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = + constexpr auto sfc_cde_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, Sequence<1, @@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS @@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // make sure it's safe to read from LDS block_sync_lds(); - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + // each block loads its C data from LDS, D from 
global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); if constexpr(access_id < num_access - 1) { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); } }); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index ea074144b6..0235fa2d98 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -165,6 +165,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3 oob_val = oob_val & is_src_valid; + // TODO: With column-major matrices this step restricts the transferred tensor slice + // to just one element, which consequently prevents using atomic operations if the + // matrix data type is on 16 bits. 
if constexpr(SrcScalarPerVectors{}[i] == 1) { auto data_types = SrcDatas{}; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 429df2413f..bca68764f9 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -270,8 +270,8 @@ struct wmma_type __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { @@ -390,8 +390,8 @@ struct wmma_type __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { @@ -793,6 +793,9 @@ struct WmmaGemm "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!"); static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { + // Integer wmma operators need extra input flags to indicate if the input is signed or + // unsigned. At the moment CK supports only signed integer inputs, so these flags are + // hardcoded. 
if constexpr(!TransposeC) { wmma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index bf7f1b4fa4..7164f345cd 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -54,6 +54,7 @@ using MFMA = ck::tensor_layout::gemm::MFMA; using Row_Tuple = ck::Tuple; using Row_Row_Tuple = ck::Tuple; +using Row_Col_Tuple = ck::Tuple; // Conv layout // diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 030f3c2760..076474de36 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#endif + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); +#endif -// GEMM + Add + +// GEMM + Add template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - Add>> + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with Add at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + +// GEMM + Add +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + 
PassThrough, + Add>> { using DeviceOp = DeviceGemmMultipleD> op_ptrs; +#ifdef CK_USE_XDL #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -104,10 +209,32 @@ struct DeviceOperationInstanceFactory< } #endif +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + return op_ptrs; } }; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp index 99b2ad1315..33a01cb68b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -11,11 +11,13 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#if defined(CK_ENABLE_FP16) namespace ck { namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( std::vector>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA + +// GEMM + Add + FastGelu +// DeviceGemmMultipleDSplitK specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddAddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddAddFastGelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + constexpr bool IsAllDRowLayout = is_same_v && is_same_v; + constexpr bool IsAllDFloat16 = + is_same_v && is_same_v; + + if constexpr(is_same_v && is_same_v && + is_same_v && IsAllDRowLayout && IsAllDFloat16) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; // GEMM + Add + Add + FastGelu +// DeviceGemmMultipleD specialization template -struct DeviceOperationInstanceFactory, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AddAddFastGelu>> +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddAddFastGelu>> { using DeviceOp = DeviceGemmMultipleD, EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AddAddFastGelu>; + PassThrough, + PassThrough, + AddAddFastGelu>; static auto GetInstances() { std::vector> op_ptrs; +#if defined(CK_USE_XDL) + constexpr bool IsAllDRowLayout = is_same_v && is_same_v; + constexpr bool IsAllDFloat16 = + is_same_v && is_same_v; + if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v && IsAllDRowLayout && IsAllDFloat16) { if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && is_same_v) { add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && is_same_v) { add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( op_ptrs); } else if constexpr(is_same_v && 
is_same_v && - is_same_v && is_same_v && is_same_v) { add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && is_same_v) { add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddAddFastGelu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -150,3 +312,4 @@ struct DeviceOperationInstanceFactory>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA // GEMM + Add + FastGelu +// DeviceGemmMultipleDSplitK specialization template -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - AddFastGelu>> +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>; + + static 
auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + op_ptrs); + } + } + +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; + +// GEMM + Add + FastGelu +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>> { using DeviceOp = DeviceGemmMultipleD> op_ptrs; -#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) +#if defined(CK_USE_XDL) +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -143,7 +280,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif +#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8 #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) if constexpr(is_same_v && is_same_v && @@ -156,8 +293,9 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif +#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8 +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && 
is_same_v) { @@ -186,6 +324,29 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp index 481915d00b..7a38f43a9a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -19,6 +19,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( std::vector>>&); +#endif + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>&); +#endif + +// GEMM + Add + Multiply +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddMultiply>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddMultiply>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL + +#endif + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + 
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; // GEMM + Add + Multiply template > op_ptrs; +#ifdef CK_USE_XDL if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) @@ -144,6 +285,27 @@ struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddMultiply>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index 293e14b811..4e706bb0c5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#endif + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); +#endif // GEMM + Add + Relu template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - AddRelu>> + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddRelu at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + +// GEMM + Add + Relu +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + 
ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>> { using DeviceOp = DeviceGemmMultipleD> op_ptrs; +#ifdef CK_USE_XDL #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -106,10 +212,32 @@ struct DeviceOperationInstanceFactory< } #endif +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + return op_ptrs; } }; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp index fbf45852ce..0b140fb07d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA + +// GEMM + Add + Silu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddSilu>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddSilu>; + + static auto GetInstances() + { + + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // no split-k xdl implementations +#endif // CL_USE_XDL +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA + return op_ptrs; + } +}; // GEMM + Add + Silu template > op_ptrs; +#if defined(CK_USE_XDL) + #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -105,7 +210,28 @@ struct DeviceOperationInstanceFactory< } } #endif +#endif // CL_USE_XDL +#if 
defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddSilu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 6ee88bd855..5c58a7f239 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -16,7 +16,8 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) +#if defined(CK_USE_XDL) +#if defined(CK_ENABLE_FP16) void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); -#endif -#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA) +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_INT8) void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector>>& instances); -#endif +#endif // CK_ENABLE_INT8 + +#if defined(CK_ENABLE_FP16) +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + // GEMM + Bilinear 
template -struct DeviceOperationInstanceFactory, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::Bilinear>> +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Bilinear>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddBilinear at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; + +// GEMM + Bilinear +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Bilinear>> { using DeviceOp = DeviceGemmMultipleD, EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - 
ck::tensor_operation::element_wise::Bilinear>; + PassThrough, + PassThrough, + Bilinear>; static auto GetInstances() { std::vector> op_ptrs; -#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) +#if defined(CK_USE_XDL) +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -188,8 +326,31 @@ struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Bilinear>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } + + // Bilinear wmma i8 instances are using DeviceGemmMultipleD interface. +#if defined(CK_ENABLE_INT8) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -214,7 +375,8 @@ struct DeviceOperationInstanceFactory>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA + +// GEMM + Add + FastGelu +// DeviceGemmMultipleDSplitK specialization +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemmMultipleDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + 
is_same_v) + { + add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; // GEMM + FastGelu template -struct DeviceOperationInstanceFactory> +struct DeviceOperationInstanceFactory> { using DeviceOp = DeviceGemmMultipleD> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v && is_same_v && is_same_v) { @@ -127,6 +255,28 @@ struct DeviceOperationInstanceFactory; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -136,4 +286,4 @@ struct DeviceOperationInstanceFactory>>&); -#endif +#endif // CK_ENABLE_FP8 +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); +#ifdef CK_USE_WMMA_FP8 +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA_FP8 +#endif // CK_USE_WMMA -// GEMM + Multiply + Add template -struct DeviceOperationInstanceFactory, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::MultiplyAdd>> +struct DeviceOperationInstanceFactory, + ELayout, + 
ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + MultiplyAdd>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + MultiplyAdd>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleDSplitK with MultiplyAdd at the moment +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA +#ifdef CK_USE_WMMA_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; + +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + MultiplyAdd>> { using DeviceOp = DeviceGemmMultipleD, EDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::MultiplyAdd>; + PassThrough, + PassThrough, + MultiplyAdd>; static auto 
GetInstances() { std::vector> op_ptrs; - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#ifdef CK_USE_XDL + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -133,10 +279,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -153,7 +299,29 @@ struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + MultiplyAdd>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp index 6475b801b8..5d520cd046 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP8 #ifdef CK_ENABLE_BF16 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1( @@ -199,7 +200,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i PassThrough, PassThrough, MultiplyMultiply>>>& instances); -#endif +#endif // CK_ENABLE_BF16 #ifdef CK_ENABLE_FP16 void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances( std::vector>>& instances); -#endif -#endif +#endif // CK_ENABLE_FP16 +#endif // CK_ENABLE_FP8 #ifdef CK_ENABLE_FP16 void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1( @@ -463,7 +464,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in PassThrough, PassThrough, MultiplyMultiply>>>& instances); -#endif +#endif // CK_ENABLE_FP16 #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8)) void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances( @@ -544,7 +545,62 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_in PassThrough, MultiplyMultiply>>>& instances); -#endif +#endif // CK_ENABLE_FP16 || CK_ENABLE_INT8 +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_USE_WMMA template -struct DeviceOperationInstanceFactory, - CLayout, - ADataType, - BDataType, - DsDataType, - CDataType, - 
ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::MultiplyMultiply>> +struct DeviceOperationInstanceFactory, + CLayout, + ADataType, + BDataType, + DsDataType, + CDataType, + PassThrough, + PassThrough, + MultiplyMultiply>> { - using DeviceOp = - DeviceGemmMultipleDSplitK, - CLayout, - ADataType, - BDataType, - DsDataType, - CDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::MultiplyMultiply>; + using DeviceOp = DeviceGemmMultipleDSplitK, + CLayout, + ADataType, + BDataType, + DsDataType, + CDataType, + PassThrough, + PassThrough, + MultiplyMultiply>; static auto GetInstances() { std::vector> op_ptrs; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP8 #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -624,7 +679,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -665,8 +720,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -691,6 +746,51 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances( + op_ptrs); + } + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances( + op_ptrs); + } + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances( + op_ptrs); + } + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA + 
return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 1eaaa7e6ba..56c8335d39 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -75,7 +75,7 @@ function(add_instance_library INSTANCE_NAME) endif() # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_") + if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -117,13 +117,13 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 + #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 
gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) endif() if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) @@ -136,7 +136,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) endif() if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) @@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found xdl, dl, and wmma instances, but none of those meet the 
target list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 298da1fbef..478e9a8ab8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,5 +1,8 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..b3f862f9cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void 
add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..ec8fe54888 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + 
std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index 04ae90bc5b..ab8023d1ba 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,5 +1,10 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..8c8006cd3d --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0, d1) +// elementwise(c, d0, d1) = fastgelu(c + d0 + d1) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m, n] + +template +using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + 
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 
2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..e2a99fea9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0, d1) +// elementwise(c, d0, d1) = fastgelu(c + d0 + d1) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m, n] + +template +using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 
64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..10dfce38a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0, d1) +// elementwise(c, d0, d1) = fastgelu(c + d0 + d1) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m, n] + +template +using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, 
S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + 
+} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..5307b44389 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0, d1) +// elementwise(c, d0, d1) = fastgelu(c + d0 + d1) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m, n] + +template +using 
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, 
AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, 
Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 
1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // 
namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 45d6abce01..46f0c3b9c6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,9 +1,14 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_fastgelu_instance - device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + + device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..cfae2c4508 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0) +// elementwise(c, d0) = fastgelu(c + d0) +// output: e[m, n] +// input: a[k, m], b[k, n], d0[m, n] + +template +using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 
S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 
1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000..00e06c3441 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0) +// elementwise(c, d0) = fastgelu(c + d0) +// output: e[m, n] +// input: a[k, m], b[n, k], d0[m, n] + +template +using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< 
Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + 
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..1bc634de38 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0) +// elementwise(c, d0) = fastgelu(c + d0) +// output: e[m, n] +// input: a[m, k], b[k, n], d0[m, n] + +template +using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< 
Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + 
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000..4a8643d553 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0) +// elementwise(c, d0) = fastgelu(c + d0) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n] + +template +using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| 
Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, 
F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, 
F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, 
Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index d859078ca9..2e6bdca234 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,7 +1,11 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_multiply_instance 
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..6c6f354d6f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template +using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + 
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 
16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances< + 
GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..a56efff220 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template +using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| 
AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, 
F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + 
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..a92f843de4 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template +using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| 
Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 
2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + 
+} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..6b092fb2a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template +using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| 
MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, 
PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, 
Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, 
Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 043bdab001..1bdf611907 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,5 +1,8 @@ -# ONLY 
XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp ) + diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..35c373a0e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template <GemmSpecialization GemmSpec> + +// e = elementwise((a * b), d0) +// output: e[m, n] +// input: a[m, k], b[k, n], d0[m, n] +using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster|
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void 
add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..794b7f0e3e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template <GemmSpecialization GemmSpec> + +// e = elementwise((a * b), d0) +// output: e[m, n] +// input: a[m, k], b[k, n], d0[m, n] +using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster|
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void 
add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index e6ca64cdc1..565096dd61 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,5 +1,6 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_silu_instance + device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..491f25dec8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise((a * b), d0) +// output: e[m, n] +// input: a[m, k], b[k, n], d0[m, n] + +template <typename T_dataType, typename T_tupleDataType, GemmSpecialization GemmSpec> +using device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster|
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 
8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 
128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + 
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); +} + +// implement bf16 with same parameters to avoid duplication +void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 61892e708c..39e83495d4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -4,6 +4,10 @@ add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..142c89c80b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n]) +template +using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, 
+ DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + 
// clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000..cbf0fe6563 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n]) +template +using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, 
Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, 
Bilinear, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..c8a7a66a93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n]) +template +using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, 
Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, 
Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000..57ea32b083 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n]) +template +using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, 
V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, 
V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, 
V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp index 6a23b70321..a948a59c00 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp @@ -45,7 +45,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, 
GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + // M/N/K padding // N % 16 == 0 && K % 16 == 0 //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -55,7 +55,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, 1, 1, 2, S<1, 32, 1, 8>, 16>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + // M/N/K padding // N % 8 == 0 && K % 8 == 0 //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -65,7 +65,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, 
PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 8 == 0 && K % 8 == 0 @@ -76,7 +75,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, 
Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, // M/N/K padding // N % 1 == 0 && K % 8 == 0 @@ -86,8 +84,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1> // clang-format 
on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 2f45401ec6..f3273fb8ed 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,5 +1,10 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_fastgelu_instance + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..6bb4f4a0e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise(a * b) +// elementwise(c) = fastgelu(c) +// output: e[m, n] +// input: a[m, k], b[n, k] + +template +using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, 
F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance{}); +} + +} // 
namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..5e0a9da5a8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise(a * b) +// elementwise(c) = fastgelu(c) +// output: e[m, n] +// input: a[m, k], b[n, k] + +template +using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| 
A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, 
F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, 
Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..cba209b408 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro 
Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise(a * b) +// elementwise(c) = fastgelu(c) +// output: e[m, n] +// input: a[m, k], b[n, k] + +template +using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + 
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 
8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..27d34bfe63 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +// e = elementwise(a * b) +// elementwise(c) = fastgelu(c) +// output: e[m, n] +// input: a[m, k], b[n, k] + +template +using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 
1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 
32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index aba9806a74..3a27e43dd6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,7 +1,11 @@ -# ONLY XDL_KERNELS -set(GEMM_MULTIPLY_ADD_INSTANCES) -list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp - 
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp) -add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES}) +# ONLY XDL_AND_WMMA_KERNELS +add_instance_library(device_gemm_multiply_add_instance + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..6ab8f44026 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| 
Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, 
MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 
index 0000000000..30d40d7002 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 
256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, 
Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..137e67df25 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| 
Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, 
GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 
0000000000..933af4c40d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 
160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, 
F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt index 6336833c3a..0e52eac0bf 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt @@ -1,4 +1,4 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_MULTIPLY_MULTIPLY_INSTANCES) list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES @@ -38,6 +38,11 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + + 
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp + device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp + device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp + device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp ) set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..bafbe66e4b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V3 = BlockGemmPipelineVersion::v3; +static constexpr auto V1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB| + 
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, 
PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, 
Intrawave, V3, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8> + // clang-format on + >; + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..fc96eee74b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V3 = BlockGemmPipelineVersion::v3; +static constexpr auto V1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB| + 
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, 
MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>, 
+ DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8> + // clang-format on + >; + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..2397c1a760 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V3 = BlockGemmPipelineVersion::v3; +static constexpr auto V1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB| + 
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, 
PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, 
Intrawave, V3, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8> + // clang-format on + >; + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..5cc13884dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto V3 = BlockGemmPipelineVersion::v3; +static constexpr auto V1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB| + 
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, 
MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>, 
+ DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8> + // clang-format on + >; + +void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp index 6f6d881c1e..9e4d30142b 100644 --- a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp @@ -33,7 +33,7 @@ template bool profile_gemm_add_fastgelu_impl(int do_verification, int init_method, - bool /*do_log*/, + bool do_log, bool time_kernel, int M, int N, @@ -213,6 +213,17 @@ bool profile_gemm_add_fastgelu_impl(int do_verification, { e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + if(do_log) + { + LogRangeAsType( + std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",") + << std::endl; + } + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); } } diff --git a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp index 25871dfb2e..fcb546fe96 100644 --- a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp
+++ b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp @@ -35,7 +35,7 @@ template bool profile_gemm_add_multiply_impl(int do_verification, int init_method, - bool /*do_log*/, + bool do_log, bool time_kernel, int M, int N, @@ -223,6 +223,17 @@ bool profile_gemm_add_multiply_impl(int do_verification, { e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + if(do_log) + { + LogRangeAsType( + std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",") + << std::endl; + } + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); } } diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100755 new mode 100644 index 640b192baf..77b5067eae --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 4700a34e9d..817ebb47c3 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -44,22 +44,16 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) - list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() - list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") - list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) @@ -70,6 +64,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_OPS profile_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) @@ -85,11 +80,14 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES 
"fp16" OR NOT DEFINED DTYPES)) OR - (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))) + (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")) list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp) endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") + list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp) +endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) @@ -98,6 +96,15 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) + list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_OPS profile_gemm_add.cpp) + endif() endif() if(DL_KERNELS) @@ -153,24 +160,18 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) endif() list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) - list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") - list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance) list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance) @@ -185,6 +186,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance) list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance) @@ -206,9 +208,12 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR - (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))) + (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" )) list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance) endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") + list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance) +endif() 
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) @@ -219,6 +224,16 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) + list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) + endif() endif() if(DL_KERNELS) diff --git a/profiler/src/profile_gemm_add_multiply.cpp b/profiler/src/profile_gemm_add_multiply.cpp index 560467c264..f8ec7abb66 100644 --- a/profiler/src/profile_gemm_add_multiply.cpp +++ b/profiler/src/profile_gemm_add_multiply.cpp @@ -35,10 +35,10 @@ int profile_gemm_add_multiply(int argc, char* argv[]) // clang-format off printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); - printf("arg3: matrix layout (0: E[m, n] = AddMultiply((A[m, k] * B[k, n] + D0[m, n]) x D1[m, n]);\n"); - printf(" 1: E[m, n] = AddMultiply((A[m, k] * B[k, n] + D0[m, n]) x D1[m, n]);\n"); - printf(" 2: E[m, n] = AddMultiply((A[m, k] * B[k, n] + D0[m, n]) x D1[m, n]);\n"); - printf(" 3: E[m, n] = AddMultiply((A[m, k] * B[k, n] + D0[m, n]) x D1[m, n]))\n"); + printf("arg3: matrix layout (0: E[m, n] = (A[m, k] * B[k, n] + D0[m, n]) x D1[m, n];\n"); + printf(" 1: E[m, n] = (A[m, k] * B[n, 
k] + D0[m, n]) x D1[m, n];\n"); + printf(" 2: E[m, n] = (A[k, m] * B[k, n] + D0[m, n]) x D1[m, n];\n"); + printf(" 3: E[m, n] = (A[k, m] * B[n, k] + D0[m, n]) x D1[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); diff --git a/profiler/src/profile_gemm_add_silu.cpp b/profiler/src/profile_gemm_add_silu.cpp index daaaef0fa2..a35c0a4092 100644 --- a/profiler/src/profile_gemm_add_silu.cpp +++ b/profiler/src/profile_gemm_add_silu.cpp @@ -27,15 +27,17 @@ int profile_gemm_add_silu(int argc, char* argv[]) enum struct MatrixDataType { - F16_INT8_F16_F16, // 0 - BF16_INT8_BF16_BF16, // 1 + F16_INT8_F16_F16 = 0, + BF16_INT8_BF16_BF16 = 1, + F16_F16_F16_F16 = 2, + BF16_BF16_BF16_BF16 = 3 }; if(argc != 15) { // clang-format off printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: f16&i8 1: bf16&i8)\n"); + printf("arg2: data type (0: f16&i8 1: bf16&i8 2: f16&f16 3: bf16&bf16)\n"); printf("arg3: matrix layout (0: E[m, n] = ReLU(A[m, k] * B[k, n] + D0[m, n]);\n"); printf(" 1: E[m, n] = ReLU(A[m, k] * B[n, k] + D0[m, n]);\n"); printf(" 2: E[m, n] = ReLU(A[k, m] * B[k, n] + D0[m, n]);\n"); @@ -128,6 +130,14 @@ int profile_gemm_add_silu(int argc, char* argv[]) { return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); } + else if(data_type == MatrixDataType::BF16_BF16_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(BF16{}, BF16{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp 
index 42192b5985..58984b324b 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -92,9 +92,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; - using F8 = ck::f8_t; - using I8 = int8_t; - using I32 = int; +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8) + using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_INT8 + using I8 = int8_t; + using I32 = int; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -163,32 +167,33 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) return pass ? 0 : 1; }; +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8) if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile( F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile( F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) +#endif // CK_USE_XDL || CK_USE_WMMA_FP8 +#ifdef CK_ENABLE_INT8 + if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile( I8{}, I8{}, I8{}, I32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile( I8{}, I8{}, I8{}, I32{}, F16{}, F16{}, F16{}, Row{},
Col{}, Row{}, Col{}, Row{}); } +#endif // CK_ENABLE_INT8 - else - { - std::cout << "this data_type & layout is not implemented" << std::endl; - return 1; - } + std::cout << "this data_type & layout is not implemented" << std::endl; + return 1; } REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multiply_multiply); diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp old mode 100755 new mode 100644 index 40ae0d70f5..3fd61db138 --- a/profiler/src/profile_gemm_universal_streamk.cpp +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc old mode 100755 new mode 100644 diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index ab4c781847..fe0a08c0c9 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,19 +1,71 @@ -add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +# Implements test instances for MultipleD with xdl and wmma support.
+ +add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance) +endif() + +add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +endif() + +add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() + +add_gtest_executable(test_gemm_add_silu_wmma test_gemm_add_silu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu_wmma PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() + +add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() + +add_gtest_executable(test_gemm_fastgelu_wmma test_gemm_fastgelu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_fastgelu_wmma PRIVATE utility device_gemm_fastgelu_instance) +endif() + +add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) + target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) + target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance) endif() -add_gtest_executable(test_gemm_add_silu 
test_gemm_add_silu_xdl.cpp) +add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) + target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance) endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +add_gtest_executable(test_gemm_multiply_multiply_wmma test_gemm_multiply_multiply_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) + target_link_libraries(test_gemm_multiply_multiply_wmma PRIVATE utility device_gemm_multiply_multiply_instance) endif() + +add_gtest_executable(test_gemm_add_multiply_wmma test_gemm_add_multiply_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_multiply_wmma PRIVATE utility device_gemm_add_multiply_instance) +endif() + +add_gtest_executable(test_gemm_multiply_add_wmma test_gemm_multiply_add_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multiply_add_wmma PRIVATE utility device_gemm_multiply_add_instance) +endif() + +add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance) +endif() + +add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance) +endif() \ No newline at end of file diff --git a/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp new file mode 100644 index 0000000000..25da138a04 --- /dev/null +++ b/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro 
Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_add_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddAddFastgelu : public TestGemmD0D1Common +{ + using ProfileCall = typename TestGemmD0D1Common::ProfileCall; + + public: + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_add_fastgelu_impl< + typename TestGemmD0D1Common::ADataType, + typename TestGemmD0D1Common::BDataType, + typename TestGemmD0D1Common::AccDataType, + typename TestGemmD0D1Common::D0DataType, + typename TestGemmD0D1Common::D1DataType, + typename TestGemmD0D1Common::EDataType, + typename TestGemmD0D1Common::ALayout, + typename TestGemmD0D1Common::BLayout, + typename TestGemmD0D1Common::D0Layout, + typename TestGemmD0D1Common::D1Layout, + typename TestGemmD0D1Common::ELayout>; + } +}; + +using KernelTypes = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddAddFastgelu, KernelTypes); +TYPED_TEST(TestGemmAddAddFastgelu, Test_FP16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp new file mode 100644 index 0000000000..df70a0cc99 --- /dev/null +++ b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddFastgelu : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_fastgelu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); +TYPED_TEST(TestGemmAddFastgelu, Test_FP16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index 1b12ab7528..0e034f46b5 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -1,37 +1,29 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddFastgelu : public TestGemmAdd +class TestGemmAddFastgelu : public TestGemmD0Common { - private: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; + using ProfileCall = typename TestGemmD0Common::ProfileCall; - constexpr static auto ProfileGemmAddFastgeluImpl = - ck::profiler::profile_gemm_add_fastgelu_impl; - - decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; } + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_fastgelu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_add_multiply_wmma.cpp b/test/gemm_add/test_gemm_add_multiply_wmma.cpp new file mode 100644 index 0000000000..be4a99d69f --- /dev/null +++ b/test/gemm_add/test_gemm_add_multiply_wmma.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "test_gemm_common.hpp" +#include "profiler/profile_gemm_add_multiply_impl.hpp" + +template +class TestGemmAddMultiply : public TestGemmD0D1Common +{ + using ProfileCall = typename TestGemmD0D1Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_multiply_impl< + typename TestGemmD0D1Common::ADataType, + typename TestGemmD0D1Common::BDataType, + typename TestGemmD0D1Common::AccDataType, + typename TestGemmD0D1Common::D0DataType, + typename TestGemmD0D1Common::D1DataType, + typename TestGemmD0D1Common::EDataType, + typename TestGemmD0D1Common::ALayout, + typename TestGemmD0D1Common::BLayout, + typename TestGemmD0D1Common::D0Layout, + typename TestGemmD0D1Common::D1Layout, + typename TestGemmD0D1Common::ELayout>; + } +}; + +using KernelTypes = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddMultiply, KernelTypes); +// Due to F16 shuffle data type tests has to run with limited K size. Change instances to FP32? +TYPED_TEST(TestGemmAddMultiply, Test) { this->Run({{16, 32, 64}, {2048, 1024, 256}}); } diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp new file mode 100644 index 0000000000..76c66a11b1 --- /dev/null +++ b/test/gemm_add/test_gemm_add_relu_wmma.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_relu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddRelu : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_relu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); +TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_relu_xdl.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp index e8b769b1cb..4b445e8e41 100644 --- a/test/gemm_add/test_gemm_add_relu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -1,37 +1,29 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddRelu : public TestGemmAdd +class TestGemmAddRelu : public TestGemmD0Common { - private: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; + using ProfileCall = typename TestGemmD0Common::ProfileCall; - constexpr static auto ProfileGemmAddReluImpl = - ck::profiler::profile_gemm_add_relu_impl; - - decltype(ProfileGemmAddReluImpl) GetImpl() override { return ProfileGemmAddReluImpl; } + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_relu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_add_silu_wmma.cpp b/test/gemm_add/test_gemm_add_silu_wmma.cpp new file mode 100644 index 0000000000..7afa68dfe9 --- /dev/null +++ b/test/gemm_add/test_gemm_add_silu_wmma.cpp @@ -0,0 +1,34 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_silu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddSilu : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_silu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes); +TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_silu_xdl.cpp b/test/gemm_add/test_gemm_add_silu_xdl.cpp index 75fa59a8e7..6bd0ee422d 100644 --- a/test/gemm_add/test_gemm_add_silu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -1,37 +1,29 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddSilu : public TestGemmAdd +class TestGemmAddSilu : public TestGemmD0Common { - private: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; + using ProfileCall = typename TestGemmD0Common::ProfileCall; - constexpr static auto ProfileGemmAddSiluImpl = - ck::profiler::profile_gemm_add_silu_impl; - - decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; } + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_silu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp new file mode 100644 index 0000000000..ae08d50fcc --- /dev/null +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAdd : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_impl::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); +TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_xdl.cpp b/test/gemm_add/test_gemm_add_xdl.cpp new file mode 100644 index 0000000000..6696c1ccf6 --- /dev/null +++ b/test/gemm_add/test_gemm_add_xdl.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAdd : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_impl::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); +TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_xdl.hpp b/test/gemm_add/test_gemm_add_xdl.hpp deleted file mode 100644 index 11d3d1c10a..0000000000 --- a/test/gemm_add/test_gemm_add_xdl.hpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_add_impl.hpp" - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using I8 = int8_t; -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; - -template -class TestGemmAdd : public ::testing::Test -{ - protected: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; - - constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; - - virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; } - - void Run() - { - std::vector> lengths = { - {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; - - bool all_success = true; - - for(auto length : lengths) - { - int M = length[0]; - int N = length[1]; - int K = length[2]; - int StrideA = ck::is_same_v ? K : M; - int StrideB = ck::is_same_v ? N : K; - int StrideD0 = ck::is_same_v ? N : M; - int StrideE = ck::is_same_v ? 
N : M; - - all_success = - all_success & - GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); - } - - EXPECT_TRUE(all_success); - } -}; - -using KernelTypes = ::testing::Types, - std::tuple>; - -TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); -TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_bilinear_wmma.cpp b/test/gemm_add/test_gemm_bilinear_wmma.cpp new file mode 100644 index 0000000000..dfa8ac7121 --- /dev/null +++ b/test/gemm_add/test_gemm_bilinear_wmma.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_bilinear_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmBilinear : public ::testing::Test +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmBilinearImpl = + ck::profiler::profile_gemm_bilinear_impl; + + public: + void Run(TestMatrixSizes const& lengths) + { + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? 
N : M; + + all_success = + all_success & + ProfileGemmBilinearImpl( + 1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideE, 1.F, 1.F); + } + + EXPECT_TRUE(all_success); + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmBilinear, KernelTypes); +TYPED_TEST(TestGemmBilinear, Test) { this->Run(DefaultTestMatrixSizes); } diff --git a/test/gemm_add/test_gemm_common.hpp b/test/gemm_add/test_gemm_common.hpp new file mode 100644 index 0000000000..9ab6c335e9 --- /dev/null +++ b/test/gemm_add/test_gemm_common.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using I32 = int32_t; +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using F8 = ck::f8_t; + +// M, N, K +using TestMatrixSizes = std::vector>; + +static const TestMatrixSizes DefaultTestMatrixSizes = { + {16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}}; + +template +class TestGemmCommon : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using BLayout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int); + + virtual ProfileCall GetImpl() = 0; + + void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes) + { + bool all_success = true; + + 
for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +template +class TestGemmD0Common : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, int); + + virtual ProfileCall GetImpl() = 0; + + void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes) + { + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? 
N : M; + + all_success = + all_success & + GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +template +class TestGemmD0D1Common : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using D1DataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + using ALayout = std::tuple_element_t<6, Tuple>; + using BLayout = std::tuple_element_t<7, Tuple>; + using D0Layout = std::tuple_element_t<8, Tuple>; + using D1Layout = std::tuple_element_t<9, Tuple>; + using ELayout = std::tuple_element_t<10, Tuple>; + + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, int, int); + + virtual ProfileCall GetImpl() = 0; + + void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes) + { + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideD1 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()( + 1, 1, false, true, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; diff --git a/test/gemm_add/test_gemm_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_fastgelu_wmma.cpp new file mode 100644 index 0000000000..d8dd218ec6 --- /dev/null +++ b/test/gemm_add/test_gemm_fastgelu_wmma.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmFastgelu : public TestGemmCommon +{ + using ProfileCall = typename TestGemmCommon::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_fastgelu_impl::ADataType, + typename TestGemmCommon::BDataType, + typename TestGemmCommon::AccDataType, + typename TestGemmCommon::EDataType, + typename TestGemmCommon::ALayout, + typename TestGemmCommon::BLayout, + typename TestGemmCommon::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmFastgelu, KernelTypes); +TYPED_TEST(TestGemmFastgelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_multiply_add_wmma.cpp b/test/gemm_add/test_gemm_multiply_add_wmma.cpp new file mode 100644 index 0000000000..2531464f72 --- /dev/null +++ b/test/gemm_add/test_gemm_multiply_add_wmma.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "test_gemm_common.hpp" +#include "profiler/profile_gemm_multiply_add_impl.hpp" + +template +class TestGemmMultiplyAdd : public TestGemmD0D1Common +{ + using ProfileCall = typename TestGemmD0D1Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_multiply_add_impl< + typename TestGemmD0D1Common::ADataType, + typename TestGemmD0D1Common::BDataType, + typename TestGemmD0D1Common::AccDataType, + typename TestGemmD0D1Common::D0DataType, + typename TestGemmD0D1Common::D1DataType, + typename TestGemmD0D1Common::EDataType, + typename TestGemmD0D1Common::ALayout, + typename TestGemmD0D1Common::BLayout, + typename TestGemmD0D1Common::D0Layout, + typename TestGemmD0D1Common::D1Layout, + typename TestGemmD0D1Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types< +#ifdef CK_USE_WMMA_FP8 + std::tuple, + std::tuple, +#endif + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmMultiplyAdd, KernelTypes); +// Due to F16 shuffle data type tests has to run with limited K size. Change instances to FP32? +TYPED_TEST(TestGemmMultiplyAdd, Test) { this->Run({{16, 32, 64}, {2048, 1024, 256}}); } diff --git a/test/gemm_add/test_gemm_multiply_multiply_wmma.cpp b/test/gemm_add/test_gemm_multiply_multiply_wmma.cpp new file mode 100644 index 0000000000..74e900c43f --- /dev/null +++ b/test/gemm_add/test_gemm_multiply_multiply_wmma.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multiply_multiply_impl.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using I32 = int32_t; +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +class TestGemmMultiplyMultiply : public ::testing::Test +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using D1DataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + using ALayout = std::tuple_element_t<6, Tuple>; + using BLayout = std::tuple_element_t<7, Tuple>; + using D0Layout = std::tuple_element_t<8, Tuple>; + using D1Layout = std::tuple_element_t<9, Tuple>; + using ELayout = std::tuple_element_t<10, Tuple>; + + constexpr static auto ProfileGemmMultiplyMultiplyImpl = + ck::profiler::profile_gemm_multiply_multiply_impl; + + public: + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideD1 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? 
N : M; + + all_success = all_success & ProfileGemmMultiplyMultiplyImpl(1, + 1, + false, + true, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE, + 1, + 1, + 1, + 0); + } + + EXPECT_TRUE(all_success); + } +}; + +using KernelTypes = ::testing::Types< +#ifdef CK_USE_WMMA_FP8 + std::tuple, + std::tuple, +#endif + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmMultiplyMultiply, KernelTypes); +TYPED_TEST(TestGemmMultiplyMultiply, Test) { this->Run(); }