diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc index d6de22bdae391..d861db1e141fe 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc @@ -304,15 +304,20 @@ q.submit([&](sycl::handler& cgh) { joint_matrix tC; joint_matrix_fill(sg, tC, 0); for (int k = 0; k < K; k += tK) { - joint_matrix_load(sg, tA, accA + sg_startx * tM * K + k, K); - joint_matrix_load(sg, tB, accB + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_load(sg, tA, + accA.template get_multi_ptr() + + sg_startx * tM * K + k, K); + joint_matrix_load(sg, tB, + accB.template get_multi_ptr() + + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4); + joint_matrix_mad(sg, tC, tA, tB, tC); } - auto wi_data_c = ext::intel::experimental::matrix::get_wi_data(sg, tC); - for (int i = 0; i < wi_data_c.length(); i++) - wi_data_c[i] *= alpha; + joint_matrix_apply(sg, tC, [=](int8_t x) { + x *= alpha; + }); joint_matrix_store(sg, tC, - accC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); + accC.template get_multi_ptr() + + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); }); }); q.wait(); diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index 94c2bebe04906..65c8508eab668 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout. ```c++ namespace sycl::ext::oneapi::experimental::matrix { -template -joint_matrix -joint_matrix_mad(Group g, +template +void joint_matrix_mad(Group g, + joint_matrix &D, const joint_matrix &A, const joint_matrix &B, const joint_matrix &C); @@ -287,7 +287,7 @@ joint_matrix_mad(Group g, ``` The matrix multiply and add function performs the multiply operation on the matrices `A` and `B`, accumulates the result with `C` and returns -the result. +the result into the matrix `D`. Each device supports only certain combinations of types for the `A`, `B`, and `C` matrices. The application must use the query operations @@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE}; int8_t *memA = malloc_shared(M*K, q); int8_t *memB = malloc_shared(K*N, q); int32_t *memC = malloc_shared(M*N, q); +auto pA = address_space_cast(memA); +auto pB = address_space_cast(memB); +auto pC = address_space_cast(memC); q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item) [[sycl::reqd_sub_group_size(SG_SIZE)]] { const auto global_idx = item.get_global_id(0); @@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item) joint_matrix tC; joint_matrix_fill(sg, tC, 0); for (int k = 0; k < K; k += tK) { - joint_matrix_load(sg, tA, - multi_ptr(memA) + - sg_startx * tM * K + k, K); - joint_matrix_load(sg, tB, - multi_ptr(memB) + - k * N + sg_starty/SG_SIZE*tN, N); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K); + joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N); + joint_matrix_mad(sg, tC, tA, tB, tC); } joint_matrix_apply(sg, tC, [=](int8_t x) { x *= alpha; }); - joint_matrix_store(sg, tC, - multi_ptr(memC) + - sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); + joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, + N, layout::row_major); }).wait(); ```