Wmma support for multiple Ds based GEMMs #2613

Open · wants to merge 212 commits into develop
Commits (212)
5bdc993
Fixed cmake errors related to gemm_bilinear. Previously, if the abov…
ancahamuraru Apr 30, 2025
1217661
Fixed cmake build errors related to test_fp8
ancahamuraru Apr 30, 2025
df929f0
Updates to support mixed precision
ancahamuraru May 1, 2025
03c0446
Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip
ancahamuraru May 1, 2025
977d8a6
Added support for F8xF16xF16 to gemm_wmma_universal
ancahamuraru May 1, 2025
636fbd5
Added support for F16xF8xF16 to gemm_wmma_universal
ancahamuraru May 5, 2025
c5d99d2
Added support for BF16xI4xBF16 to gemm_wmma_universal
ancahamuraru May 5, 2025
55f1602
Added support for F16xI4xF16 to gemm_wmma_universal
ancahamuraru May 5, 2025
d892b5a
Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead…
ancahamuraru May 8, 2025
97e0249
Added missing test class for FP16_KM_NK
ancahamuraru May 8, 2025
744262d
Pre-commit hooks fixes
ancahamuraru May 8, 2025
dd47f39
Added padding instances for f16xf16xf16
ancahamuraru May 14, 2025
8b52033
Fixed cmake errors related to gemm_bilinear. Previously, if the abov…
ancahamuraru May 14, 2025
527c844
Merge branch 'cherry-pick-5bdc993d' into '33-wip'
ancahamuraru May 14, 2025
7107bcc
Fixed cmake build errors related to test_fp8
ancahamuraru May 14, 2025
647024d
Amending changes for adding support for padding instances for f16xf1…
ancahamuraru May 14, 2025
ab28ac0
Merge branch '33-wip' of projects.streamhpc.com:amd/ai/composable_ker…
ancahamuraru May 14, 2025
e285563
Fixes for padding instances for f16xf16xf16
ancahamuraru May 14, 2025
0482b83
Added padding instances for bf16xbf16, f8xf8
ancahamuraru May 14, 2025
a3d99c6
Merge branch '32-add-the-remaining-combination-of-data-types-mixed-pr…
ancahamuraru May 15, 2025
537d19c
Added packed instances for bf16xi4xbf16
ancahamuraru May 15, 2025
e97a7f2
Added padding instances for f8xf16xf16
ancahamuraru May 15, 2025
bbda71f
Added padding instances for f16xf8xf16, f16xi4xf16
ancahamuraru May 16, 2025
a0a2bf2
Fixed typos for bf16xbf16xbf16 padding instances
ancahamuraru May 19, 2025
7975c9c
Fixed typos for padded instances
ancahamuraru May 19, 2025
dc26ee3
Added tests for fp16, KM_KN and KM_NK
ancahamuraru May 20, 2025
a08ca63
Padding not supported when BDataType is pk_i4_t. Added fix for co…
ancahamuraru May 20, 2025
0a5e6d4
Fixed typos
ancahamuraru May 20, 2025
b350bd2
Updated the set of tests for FP16
ancahamuraru May 20, 2025
ae21582
Updated the set of tests for FP16
ancahamuraru May 20, 2025
185ea0f
Fix typo
ancahamuraru May 20, 2025
b1a9a27
Merge branch '33-wip' of projects.streamhpc.com:amd/ai/composable_ker…
ancahamuraru May 20, 2025
15bfa00
Moved f16xi4 test under the correct data layout group
ancahamuraru May 20, 2025
621012c
example for gemm_universal_bf16
ApoorvaKalyani May 7, 2025
b35a195
Adding examples for gemm_wmma instances
ApoorvaKalyani May 7, 2025
8f8e631
Added the missing parameters
ApoorvaKalyani May 7, 2025
840b79d
Fixed review comments and added executable to cmakeLists
ApoorvaKalyani May 8, 2025
4b5a9ac
Fixing clang format
ApoorvaKalyani May 8, 2025
9cc5702
Fixing build errors
ApoorvaKalyani May 8, 2025
b0aa933
Fixed compilation failure.
ApoorvaKalyani May 12, 2025
c016164
Modified some code as per gemm_universal_examples
ApoorvaKalyani May 13, 2025
9d8f1e4
Fixed the gemm specialization error
ApoorvaKalyani May 13, 2025
501d957
Fixed the build errors.
ApoorvaKalyani May 15, 2025
cc818b4
Fix strides of a/b_thread_desc
ex-rzr Apr 28, 2025
af9e9ed
Load in M/NRepeat dims with thread copy's slice instead of a loop
ex-rzr Apr 28, 2025
ede7126
Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation
ex-rzr Apr 28, 2025
c414097
Implement Intrawave and Interwave variants of pipeline v1
ex-rzr Apr 30, 2025
c94c3b4
Add instances for Interwave and Intrawave v1
ex-rzr May 16, 2025
04d3fc7
Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0
ex-rzr May 16, 2025
17bc0fa
Remove instances that are too slow (mostly because of register spilling)
ex-rzr May 19, 2025
342bb57
Add a workaround for fp8/bf8->f32 packed conversion issue
ex-rzr May 20, 2025
5082a9c
Add instances for Interwave and Intrawave v1
ex-rzr May 20, 2025
c7d39a0
Enable profiling of mixed precision with f8 and int4 on WMMA
ex-rzr May 20, 2025
8b5d340
Fix segfault in profiler when B is pk_i4_t
ex-rzr May 21, 2025
b1f50b5
Remove instances that are too slow (mostly because of register spilling)
ex-rzr May 21, 2025
02bf56a
Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations
ex-rzr May 21, 2025
dd7ac95
Add test case for bf16_i4
ex-rzr May 21, 2025
eac7d35
Add missing Regular tests
ex-rzr May 21, 2025
05ad214
Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS
ex-rzr May 22, 2025
83b1419
Fix a bug that fp16_i4 validation passes only with PermuteB
ex-rzr May 22, 2025
9e70603
Use PermuteB with f16_i4 in most instances (as xdl)
ex-rzr May 22, 2025
c143bf3
Fix cache flushing for pk_i4
ex-rzr May 22, 2025
668914c
Add mixed precision examples
ex-rzr May 22, 2025
2679c0a
Disable all tests and instances with f8 on gfx11
ex-rzr May 23, 2025
a6ea604
Add FP16 KM_NK and KM_KN test suites for XDL
ex-rzr May 23, 2025
da5f962
Support multiple D in GridwiseGemm_wmma_cshuffle_v3
ex-rzr May 29, 2025
99fc05e
Use ThreadGroupTensorSliceTransfer_v7r3
ex-rzr May 29, 2025
a038ba3
Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
ex-rzr May 29, 2025
5151206
Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for…
ex-rzr May 29, 2025
f13b913
Implement DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
db51d8a
Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
22935c8
Prepare gemm_add tests for adding wmma
ex-rzr Jun 2, 2025
25f7204
Add gemm_add_fastgelu instances and test
ex-rzr Jun 2, 2025
959defb
Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with…
ex-rzr Jun 2, 2025
b8e45c7
removed unnecessary ck parts from compilation
May 30, 2025
538fa87
initial gemm_add_multiply instance implementations
May 30, 2025
8727762
fixed profiler help message for gemm_add_multiply
May 30, 2025
63513c3
improved multiply_add profiler layout help
May 30, 2025
07f75d9
fixed template arguments for test instances
Jun 2, 2025
75550ff
added test for gemm_add_multiply
Jun 3, 2025
ed047d0
Support multiple D in GridwiseGemm_wmma_cshuffle_v3
ex-rzr May 29, 2025
deebe1e
Use ThreadGroupTensorSliceTransfer_v7r3
ex-rzr May 29, 2025
7dff5fe
Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
ex-rzr May 29, 2025
89ac60d
Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for…
ex-rzr May 29, 2025
137efa7
Implement DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
e36a176
Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
bcf93e2
Prepare gemm_add tests for adding wmma
ex-rzr Jun 2, 2025
381c02d
Add gemm_add_fastgelu instances and test
ex-rzr Jun 2, 2025
9912e5f
Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with…
ex-rzr Jun 2, 2025
881bc3f
Merge branch '52-implement-multipled-in-gemm-universal' into 'feature…
ex-rzr Jun 4, 2025
4e07085
switched to splitK interface
Jun 4, 2025
8658ca6
log print added to splitk benchmarks
Jun 5, 2025
a902c57
revert main cmake comments
Jun 5, 2025
32e78b6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jun 5, 2025
0228eca
newline change reverted
Jun 6, 2025
ea9805b
added add_fastgelu instances
Jun 10, 2025
aeca8ef
revert unintended change in xdl add_fastgelu
Jun 11, 2025
4c8ea95
created gemm_add_add_fastgelu instances
Jun 11, 2025
264e1b2
created fastgelu instances
Jun 11, 2025
b4d3e41
added tests for all splitk fastgelus
Jun 12, 2025
0696f99
Added tests.
ApoorvaKalyani Jun 13, 2025
a529e3e
multiply_add instances created
Jun 13, 2025
27d86a3
updates to add_multiply splitk instances
Jun 13, 2025
61b6e9a
splitk xdl test fixes
Jun 13, 2025
ac60286
added wmma multiply_multiply instances
Jun 17, 2025
7424b4a
fixed ONLY_XDL_AND_WMMA_KERNELS tag
Jun 17, 2025
30d65b9
Added gemm_add examples for wmma v1 and v3
ApoorvaKalyani Jun 18, 2025
90c9b09
Merge branch '61-add-examples-for-bf16-and-fp16-instances-of-gemm_add…
ApoorvaKalyani Jun 18, 2025
cd0172b
fixed / worked around i8 instances
Jun 18, 2025
055bc02
Merge branch '10-implement-device_gemm_add_fastgelu-for-rdna4' into '…
Jun 18, 2025
40ce862
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 6…
Jun 18, 2025
c2077ca
Modified the v3 code to add one fp16 xdl instance.
ApoorvaKalyani Jun 18, 2025
57c3fd9
added bf16 xdl instance.
ApoorvaKalyani Jun 18, 2025
b42b6b6
adding gemm_add wmma_cshuffle and other support
ApoorvaKalyani May 26, 2025
113ea09
add instances into cmakelists
ApoorvaKalyani May 26, 2025
b129e73
This is work in progress, edited the template parameters in order to …
ApoorvaKalyani May 26, 2025
455275d
temp work saved, changed the BDataType to f16 or bf16 since wmma curr…
ApoorvaKalyani May 26, 2025
1fda499
added datatype and use clang-format-12
ApoorvaKalyani May 26, 2025
1519eaa
Fixing build errors
ApoorvaKalyani May 28, 2025
32b9500
Added instances for v3
ApoorvaKalyani Jun 11, 2025
7da9f64
Adding instances and executables
ApoorvaKalyani Jun 11, 2025
0cce81c
Code update of template parameters modified.
ApoorvaKalyani Jun 12, 2025
6df313f
Renamed file.
ApoorvaKalyani Jun 12, 2025
06d44f1
Added tests.
ApoorvaKalyani Jun 13, 2025
10d648a
resolved error tests.
ApoorvaKalyani Jun 13, 2025
ef781db
Fixing build errors
ApoorvaKalyani Jun 13, 2025
bd49ec0
Updated comments
ApoorvaKalyani Jun 13, 2025
3301ef5
removed the changes as per the MR review comment.
ApoorvaKalyani Jun 19, 2025
38d0027
Updated tests.
ApoorvaKalyani Jun 19, 2025
5e45427
fp8 instances - not tested
Jun 19, 2025
c8b3f3d
Restored the Cmake file that was reverted by mistake during rebase.
ApoorvaKalyani Jun 19, 2025
a8dec7a
fixed wmma_op test
Jun 19, 2025
78c2ee2
Updated comments.
ApoorvaKalyani Jun 19, 2025
ed5ac21
Updated the template parameter description
ApoorvaKalyani Jun 19, 2025
1c01ff6
fixed rdna4 instances
Jun 23, 2025
fb4c1b5
fixed back compatibility on gfx11
Jun 23, 2025
d7b4d51
cleanups
Jun 24, 2025
94f543c
fix ckProfiler
Jun 24, 2025
8b694c3
one more cmake fix
Jun 24, 2025
3c3136b
added fp8 instances
Jun 24, 2025
71d65d4
Updated tests to add BF16 instances as per review comment
ApoorvaKalyani Jun 25, 2025
ee8c278
Added include file and cleaned up (as per review comment)
ApoorvaKalyani Jun 25, 2025
7840db4
Updated and optimized the example code for all types.
ApoorvaKalyani Jun 25, 2025
3037858
Fixed clang format
ApoorvaKalyani Jun 25, 2025
686df33
Resolve "Implement `device_gemm_bilinear` for RDNA4"
Jun 26, 2025
4f19101
Merge branch '63-implement-device_gemm_bilinear-for-rdna4' into 'feat…
Jun 26, 2025
6ba1dc6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jun 30, 2025
eaa0452
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 6…
Jun 30, 2025
0794833
test generalization to handle FP16 shuffle better
Jun 30, 2025
bb7b307
added missing changes
Jun 30, 2025
35aab35
Added bf16 wmma instance for add_relu
ApoorvaKalyani Jun 19, 2025
6f89183
Added f16 wmma instance and corrected bf16 instance errors.
ApoorvaKalyani Jun 23, 2025
cdaff7f
Added instances to Cmake
ApoorvaKalyani Jun 24, 2025
6a116fa
Modified the template parameters to make the instances work.
ApoorvaKalyani Jul 1, 2025
bb7f665
Fixed typo in profiler
ApoorvaKalyani Jul 1, 2025
f5843dd
Added v3 instances for gemm_add_relu
ApoorvaKalyani Jul 1, 2025
ff31873
addressed core review comments
Jul 1, 2025
6ec0ad2
Added test for gemm_add_relu wmma instance
ApoorvaKalyani Jul 1, 2025
feca919
Cleaned up the code.
ApoorvaKalyani Jul 1, 2025
ba9c637
Added examples for gemm_add_relu
ApoorvaKalyani Jul 2, 2025
5c491e7
Fixing typo to resolve build errors.
ApoorvaKalyani Jul 2, 2025
8a5bb25
Fixes applied to fix the precision loss.
ApoorvaKalyani Jul 7, 2025
0551b84
fix bilinear test after merge
Jul 8, 2025
86ca6b8
Removed the old wmma instances.
ApoorvaKalyani Jul 8, 2025
9b64da2
Added wrapper and renamed the wmma_v3 instances
ApoorvaKalyani Jul 8, 2025
669befb
Updated copyrights and added wrappers.
ApoorvaKalyani Jul 8, 2025
bdfdb0c
Fixes applied according to review comments
ApoorvaKalyani Jul 8, 2025
d3a26e5
Apply 1 suggestion(s) to 1 file(s)
ApoorvaKalyani Jul 8, 2025
84b0b32
Removed the old wmma instances.
ApoorvaKalyani Jul 8, 2025
516d1f5
Updated wrapper for the v3 instances
ApoorvaKalyani Jul 8, 2025
e59d281
removed the old wmma examples
ApoorvaKalyani Jul 8, 2025
566e472
Renamed the v3 instances
ApoorvaKalyani Jul 8, 2025
9655010
Deleted the gtest file added by mistake.
ApoorvaKalyani Jul 8, 2025
536f866
Updated the profiler with wrapper
ApoorvaKalyani Jul 8, 2025
13efcc6
Fixed test errors.
ApoorvaKalyani Jul 8, 2025
55299c9
Fixed the review comments
ApoorvaKalyani Jul 9, 2025
3212507
Fixed the if condition MACROS.
ApoorvaKalyani Jul 9, 2025
21cb985
REVERTED THE PROFILER CHANGES
ApoorvaKalyani Jul 9, 2025
e1374ea
Revert "REVERTED THE PROFILER CHANGES"
ApoorvaKalyani Jul 9, 2025
9e3d87e
Revert "Fixed test errors."
ApoorvaKalyani Jul 9, 2025
ea133bf
Revert "Updated thge profiler with wrapper"
ApoorvaKalyani Jul 9, 2025
76f4bb0
Added missing wrapper instances
ApoorvaKalyani Jul 9, 2025
2738ca5
Updated copyrights.
ApoorvaKalyani Jul 9, 2025
e6ea4aa
Fixed typo.
ApoorvaKalyani Jul 9, 2025
8d64718
Fixed copyrights.
ApoorvaKalyani Jul 9, 2025
8e91755
Updated copyrights.
ApoorvaKalyani Jul 9, 2025
aea158f
updated copyrights.
ApoorvaKalyani Jul 9, 2025
a7993ab
comments on the atomics workaround
Jul 10, 2025
0dc871a
Merge branch '64-implement-device_gemm_multiply_multiply_instance-for…
Jul 11, 2025
41d4500
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jul 11, 2025
9c1314d
Merge branch '51-create-bf16-and-f16-instances-for-gemm_add-cshuffle_…
ApoorvaKalyani Jul 14, 2025
036799d
Merge branch '61-add-examples-for-bf16-and-fp16-instances-of-gemm_add…
ApoorvaKalyani Jul 14, 2025
27c0f95
Merge branch '79-add-instances-and-examples-for-device_gemm_add_relu'…
ApoorvaKalyani Jul 14, 2025
e2a75d6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jul 14, 2025
161fe6c
fixed cmake comment
Jul 14, 2025
1de5d98
Merge branch '8-implement-device_gemm_add_multiply-for-rdna4' into 'f…
Jul 14, 2025
5dc21c5
Merge branch 'develop' into feature/multiple-d-gemms
EnricoDeg Jul 28, 2025
02cb1f2
Fix bug from merge
EnricoDeg Aug 4, 2025
ec38280
Merge remote-tracking branch 'origin/develop' into 90-prepare-an-upst…
krithalith Aug 6, 2025
c434378
clang-format-18
krithalith Aug 6, 2025
8f01112
Fix compilation error
EnricoDeg Aug 6, 2025
7eaf6fb
Fix linking error
EnricoDeg Aug 6, 2025
a2f03ec
Fix bug in add and add_relu examples
EnricoDeg Aug 6, 2025
29bf53e
Fix error including file (typo)
EnricoDeg Aug 6, 2025
4c4ab8b
Quick fix to compile examples for different targets
EnricoDeg Aug 6, 2025
7478497
Fix for multi target
EnricoDeg Aug 8, 2025
99a0a67
implemented f16 and bf16 instances for gemm_silu
kabraham-streamhpc Aug 4, 2025
e9bdd0c
addressed review comments
kabraham-streamhpc Aug 6, 2025
c027fba
addressed review comments
kabraham-streamhpc Aug 6, 2025
753d178
Fix clang format
EnricoDeg Aug 11, 2025
6168bef
Merge branch 'develop' into streamhpc/multiple-ds-based-gemms-wmma
EnricoDeg Aug 12, 2025
4038d34
Fix clang format
EnricoDeg Aug 12, 2025
2 changes: 1 addition & 1 deletion example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
2 changes: 1 addition & 1 deletion example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
2 changes: 1 addition & 1 deletion example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
2 changes: 1 addition & 1 deletion example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
2 changes: 2 additions & 0 deletions example/65_gemm_multiply_multiply/CMakeLists.txt
@@ -70,3 +70,5 @@ example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpresh

example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})

+add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)
267 changes: 267 additions & 0 deletions example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp
@@ -0,0 +1,267 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

#include "ck/utility/blkgemmpipe_scheduler.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using A0DataType = F16;
using B0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;

using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;

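// CDE epilogue applied elementwise to the accumulator: e = F16(c + d0 + d1),
// fusing both auxiliary D tensors into the GEMM output.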
struct AddAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c + d0 + d1;

e = ck::type_convert<ck::half_t>(x0_f);
}
};

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;

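// GemmSpecialization::Default generates no padding logic, so M, N and K must
// already be divisible by the kernel's tile sizes (cf. the size hints in the usage message).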
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3
// clang-format off
//#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S<C, D..>| | |
< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on

int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;

// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;

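// Strides are leading dimensions: row-major A (MxK) and column-major B (KxN) both use K.
// Note StrideD = K only matches row-major D (MxN) because N == K == 4096 in the default shape.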
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = K;
ck::index_t StrideE = N;

if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);

M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);

StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}

auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;

if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};

Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}

DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};

constexpr ck::index_t NumDTensor = DsDataType::Size();

// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
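// assumption: the trailing index_t argument is KBatch (split-K factor); 1 disables split-K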
1,
a_element_op,
b_element_op,
cde_element_op);

if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}

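// Assuming CK's StreamConfig fields {stream_id, time_kernel, log_level, cold_niters, nrepeat}:
// default stream, 20 warm-up and 50 timed iterations when time_kernel is set.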
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});

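// ave_time is reported in ms, so flop / 1e9 / ms yields TFLOP/s and bytes / 1e6 / ms yields GB/s.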
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

ref_invoker.Run(ref_argument);

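// Apply the same AddAdd epilogue to the reference GEMM result to form the host-side E.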
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}

e_device_buf.FromDevice(e_m_n_device_result.mData.data());

return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}

return 0;
}
14 changes: 6 additions & 8 deletions example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
@@ -184,7 +184,6 @@ int main(int argc, char* argv[])
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
-e_device_buf.ToDevice(e_m_n_device_result.mData.data());

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -220,11 +219,12 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}

-float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
+float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});

-std::size_t flop = std::size_t(2) * M * N * K;
-std::size_t num_btype =
-    sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+std::size_t flop = std::size_t(2) * M * N * K;
+std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
+                        sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
+                        sizeof(EDataType) * M * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

@@ -233,8 +233,6 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

-e_device_buf.FromDevice(e_m_n_device_result.mData.data());

if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
22 changes: 22 additions & 0 deletions example/68_gemm_add/CMakeLists.txt
@@ -0,0 +1,22 @@
add_custom_target(example_gemm_add_xdl)

add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16)


add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16)

add_custom_target(example_gemm_add_wmma)

add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp)
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_bf16)

add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp)
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_fp16)