Skip to content

Commit afd5c6d

Browse files
authored
using syrk for performing special cases of matrix multiplication (#2509)
In this PR, the `syrk` routines from oneMKL is used to perform a rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix.
1 parent 1a7ce22 commit afd5c6d

17 files changed

+703
-165
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
* Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
1212
* Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
1313
* Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520)
14+
* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix [2509](https://github.com/IntelPython/dpnp/pull/2509)
1415

1516
### Changed
1617

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set(_module_src
3030
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp
3334
)
3435

3536
pybind11_add_module(${python_module_name} MODULE ${_module_src})
@@ -61,6 +62,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN
6162

6263
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
6364
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
65+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
6466

6567
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
6668
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "dotu.hpp"
3737
#include "gemm.hpp"
3838
#include "gemv.hpp"
39+
#include "syrk.hpp"
3940

4041
namespace blas_ns = dpnp::extensions::blas;
4142
namespace py = pybind11;
@@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void)
4849
blas_ns::init_gemm_batch_dispatch_table();
4950
blas_ns::init_gemm_dispatch_table();
5051
blas_ns::init_gemv_dispatch_vector();
52+
blas_ns::init_syrk_dispatch_vector();
5153
}
5254

5355
static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m)
7375
};
7476

7577
m.def("_dot", dot_pyapi,
76-
"Call `dot` from OneMKL BLAS library to compute "
78+
"Call `dot` from oneMKL BLAS library to compute "
7779
"the dot product of two real-valued vectors.",
7880
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
7981
py::arg("result"), py::arg("depends") = py::list());
@@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m)
9193
};
9294

9395
m.def("_dotc", dotc_pyapi,
94-
"Call `dotc` from OneMKL BLAS library to compute "
96+
"Call `dotc` from oneMKL BLAS library to compute "
9597
"the dot product of two complex vectors, "
9698
"conjugating the first vector.",
9799
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -110,37 +112,45 @@ PYBIND11_MODULE(_blas_impl, m)
110112
};
111113

112114
m.def("_dotu", dotu_pyapi,
113-
"Call `dotu` from OneMKL BLAS library to compute "
115+
"Call `dotu` from oneMKL BLAS library to compute "
114116
"the dot product of two complex vectors.",
115117
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
116118
py::arg("result"), py::arg("depends") = py::list());
117119
}
118120

119121
{
120122
m.def("_gemm", &blas_ns::gemm,
121-
"Call `gemm` from OneMKL BLAS library to compute "
123+
"Call `gemm` from oneMKL BLAS library to compute "
122124
"the matrix-matrix product with 2-D matrices.",
123125
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
124126
py::arg("resultC"), py::arg("depends") = py::list());
125127
}
126128

127129
{
128130
m.def("_gemm_batch", &blas_ns::gemm_batch,
129-
"Call `gemm_batch` from OneMKL BLAS library to compute "
131+
"Call `gemm_batch` from oneMKL BLAS library to compute "
130132
"the matrix-matrix product for a batch of 2-D matrices.",
131133
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
132134
py::arg("resultC"), py::arg("depends") = py::list());
133135
}
134136

135137
{
136138
m.def("_gemv", &blas_ns::gemv,
137-
"Call `gemv` from OneMKL BLAS library to compute "
139+
"Call `gemv` from oneMKL BLAS library to compute "
138140
"the matrix-vector product with a general matrix.",
139141
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
140142
py::arg("vectorY"), py::arg("transpose"),
141143
py::arg("depends") = py::list());
142144
}
143145

146+
{
147+
m.def("_syrk", &blas_ns::syrk,
148+
"Call `syrk` from oneMKL BLAS library to compute "
149+
"the matrix-vector product with a general matrix.",
150+
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"),
151+
py::arg("depends") = py::list());
152+
}
153+
144154
{
145155
m.def(
146156
"_using_onemath",

dpnp/backend/extensions/blas/dot_common.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ std::pair<sycl::event, sycl::event>
128128
dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
129129
if (dot_fn == nullptr) {
130130
throw py::value_error(
131-
"Types of input vectors and result array are mismatched.");
131+
"No dot implementation is available for the specified data type "
132+
"of the input and output arrays.");
132133
}
133134

134135
char *x_typeless_ptr = vectorX.get_data();

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
119119
Tab(1), // Scaling factor for the product of matrices A and B.
120120
a, // Pointer to matrix A.
121121
lda, // Leading dimension of matrix A, which is the
122-
// stride between successive rows (for row major
123-
// layout).
122+
// stride between successive rows (for row major layout).
124123
b, // Pointer to matrix B.
125124
ldb, // Leading dimension of matrix B, similar to lda.
126125
Tab(0), // Scaling factor for matrix C.
@@ -158,7 +157,8 @@ std::tuple<sycl::event, sycl::event, bool>
158157
const int resultC_nd = resultC.get_ndim();
159158

160159
if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
161-
throw py::value_error("Input matrices must be two-dimensional.");
160+
throw py::value_error(
161+
"Input and output matrices must be two-dimensional.");
162162
}
163163

164164
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
@@ -276,6 +276,8 @@ std::tuple<sycl::event, sycl::event, bool>
276276
}
277277
}
278278
else {
279+
// both A and B are f_contig so using column-major gemm and
280+
// no transpose is needed
279281
transA = oneapi::mkl::transpose::N;
280282
transB = oneapi::mkl::transpose::N;
281283
lda = m;
@@ -303,7 +305,8 @@ std::tuple<sycl::event, sycl::event, bool>
303305
gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
304306
if (gemm_fn == nullptr) {
305307
throw py::value_error(
306-
"Types of input matrices and result matrix are mismatched.");
308+
"No gemm implementation is available for the specified data type "
309+
"of the input and output arrays.");
307310
}
308311

309312
const char *a_typeless_ptr = matrixA.get_data();

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ std::tuple<sycl::event, sycl::event, bool>
379379
gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
380380
if (gemm_batch_fn == nullptr) {
381381
throw py::value_error(
382-
"Types of input matrices and result matrix are mismatched.");
382+
"No gemm_batch implementation is available for the specified data "
383+
"type of the input and output arrays.");
383384
}
384385

385386
const char *a_typeless_ptr = matrixA.get_data();

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
109109
T(1), // Scaling factor for the matrix-vector product.
110110
a, // Pointer to the input matrix A.
111111
lda, // Leading dimension of matrix A, which is the
112-
// stride between successive rows (for row major
113-
// layout).
112+
// stride between successive rows (for row major layout).
114113
x, // Pointer to the input vector x.
115114
incx, // The stride of vector x.
116115
T(0), // Scaling factor for vector y.
@@ -181,6 +180,26 @@ std::pair<sycl::event, sycl::event>
181180
const py::ssize_t *a_shape = matrixA.get_shape_raw();
182181
const py::ssize_t *x_shape = vectorX.get_shape_raw();
183182
const py::ssize_t *y_shape = vectorY.get_shape_raw();
183+
if (transpose) {
184+
if (a_shape[0] != x_shape[0]) {
185+
throw py::value_error("The number of rows in A must be equal to "
186+
"the number of elements in X.");
187+
}
188+
if (a_shape[1] != y_shape[0]) {
189+
throw py::value_error("The number of columns in A must be equal to "
190+
"the number of elements in Y.");
191+
}
192+
}
193+
else {
194+
if (a_shape[1] != x_shape[0]) {
195+
throw py::value_error("The number of columns in A must be equal to "
196+
"the number of elements in X.");
197+
}
198+
if (a_shape[0] != y_shape[0]) {
199+
throw py::value_error("The number of rows in A must be equal to "
200+
"the number of elements in Y.");
201+
}
202+
}
184203

185204
oneapi::mkl::transpose transA;
186205
std::size_t src_nelems;
@@ -234,27 +253,6 @@ std::pair<sycl::event, sycl::event>
234253
}
235254
#endif // USE_ONEMATH_CUBLAS
236255

237-
if (transpose) {
238-
if (a_shape[0] != x_shape[0]) {
239-
throw py::value_error("The number of rows in A must be equal to "
240-
"the number of elements in X.");
241-
}
242-
if (a_shape[1] != y_shape[0]) {
243-
throw py::value_error("The number of columns in A must be equal to "
244-
"the number of elements in Y.");
245-
}
246-
}
247-
else {
248-
if (a_shape[1] != x_shape[0]) {
249-
throw py::value_error("The number of columns in A must be equal to "
250-
"the number of elements in X.");
251-
}
252-
if (a_shape[0] != y_shape[0]) {
253-
throw py::value_error("The number of rows in A must be equal to "
254-
"the number of elements in Y.");
255-
}
256-
}
257-
258256
const std::int64_t lda = is_row_major ? n : m;
259257
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY);
260258
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY,
@@ -275,10 +273,11 @@ std::pair<sycl::event, sycl::event>
275273
gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_vector[type_id];
276274
if (gemv_fn == nullptr) {
277275
throw py::value_error(
278-
"Types of input arrays and result array are mismatched.");
276+
"No gemv implementation is available for the specified data type "
277+
"of the input and output arrays.");
279278
}
280279

281-
char *a_typeless_ptr = matrixA.get_data();
280+
const char *a_typeless_ptr = matrixA.get_data();
282281
char *x_typeless_ptr = vectorX.get_data();
283282
char *y_typeless_ptr = vectorY.get_data();
284283

dpnp/backend/extensions/blas/gemv.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,4 @@ extern std::pair<sycl::event, sycl::event>
4141
const std::vector<sycl::event> &depends);
4242

4343
extern void init_gemv_dispatch_vector(void);
44-
extern void init_gemv_batch_dispatch_vector(void);
4544
} // namespace dpnp::extensions::blas

0 commit comments

Comments
 (0)