-
Notifications
You must be signed in to change notification settings - Fork 795
[SYCL][Matrix] Add support for tf32 type using the unified interface #8702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
de3d230
[SYCL][Matrix] Add support for tf32 type using the unified interface
dkhaldi 88aad2c
Address Jack's comments
dkhaldi a56084d
change convert to round in the name
dkhaldi f7d7c4e
Merge remote-tracking branch 'intel_llvm/sycl' into tf32-joint-matrix
dkhaldi 0d23b8a
set xfail to test as SPIRV changes are not part of intel/llvm yet
dkhaldi 7d05e6a
correct SYCL_EXTERNAL with the new naming __DPCPP_SYCL_EXTERNAL
dkhaldi c162139
Merge remote-tracking branch 'intel_llvm/sycl' into tf32
yubingex007-a11y 1f84b04
move e2e testcases from llvm-test-suite's pr
yubingex007-a11y 93193e9
rm XFAIL from sycl/test/matrix/matrix-tf32-test.cpp
yubingex007-a11y 8caee53
fix clang-format issue
yubingex007-a11y 1557f34
address comments
yubingex007-a11y 19ada00
Merge remote-tracking branch 'intel_llvm/sycl' into tf32
yubingex007-a11y File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,10 +67,27 @@ struct joint_matrix; | |
|
||
} // namespace matrix | ||
} // namespace experimental | ||
|
||
namespace detail { | ||
// Differentiating between the "element type" and the "storage element type" | ||
template <typename T> struct helper_traits { | ||
using element_type = T; | ||
using storage_element_type = T; | ||
using fill_argument_type = T; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove fill_argument_type as it is currently unused |
||
}; | ||
|
||
template <> | ||
struct helper_traits<sycl::ext::oneapi::experimental::matrix::precision::tf32> { | ||
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32; | ||
using storage_element_type = float; | ||
using fill_argument_type = float; | ||
}; | ||
} // namespace detail | ||
} // namespace oneapi | ||
|
||
namespace intel::experimental::matrix { | ||
|
||
using namespace sycl::ext::oneapi::experimental::matrix; | ||
// Begin wi_element definition | ||
|
||
template <typename T, size_t NumRows, size_t NumCols, | ||
|
@@ -84,13 +101,21 @@ class wi_element { | |
std::size_t idx; | ||
|
||
public: | ||
using storage_element_type = | ||
typename oneapi::detail::helper_traits<T>::storage_element_type; | ||
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
Group, T, Use, NumRows, NumCols, Layout> &Mat, | ||
std::size_t i) | ||
: M(Mat), idx(i) {} | ||
operator T() { | ||
operator storage_element_type() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return __spirv_VectorExtractDynamic(M.spvm, idx); | ||
storage_element_type elem = | ||
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols, | ||
spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>(M.spvm, | ||
idx); | ||
return elem; | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
|
@@ -99,7 +124,12 @@ class wi_element { | |
|
||
explicit operator bool() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0); | ||
return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, | ||
NumCols, | ||
spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>( | ||
M.spvm, idx) != static_cast<storage_element_type>(0); | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
|
@@ -108,7 +138,8 @@ class wi_element { | |
|
||
template <typename T2> wi_element &operator=(const T2 &rhs) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx); | ||
M.spvm = __spirv_VectorInsertDynamic( | ||
M.spvm, static_cast<storage_element_type>(rhs), idx); | ||
return *this; | ||
#else | ||
(void)rhs; | ||
|
@@ -121,7 +152,13 @@ class wi_element { | |
operator=(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
M.spvm = __spirv_VectorInsertDynamic( | ||
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); | ||
M.spvm, | ||
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols, | ||
spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>(rhs.M.spvm, | ||
rhs.idx), | ||
idx); | ||
return *this; | ||
#else | ||
(void)rhs; | ||
|
@@ -135,8 +172,13 @@ class wi_element { | |
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \ | ||
M.spvm = __spirv_VectorInsertDynamic( \ | ||
M.spvm, \ | ||
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \ | ||
op static_cast<T>(rhs)), \ | ||
static_cast<storage_element_type>( \ | ||
__spirv_VectorExtractDynamic< \ | ||
storage_element_type, T, NumRows, NumCols, \ | ||
spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(M.spvm, idx) \ | ||
op static_cast<storage_element_type>(rhs)), \ | ||
idx); \ | ||
return *this; \ | ||
} | ||
|
@@ -173,7 +215,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
: M(Mat), idx(i) {} | ||
operator sycl::ext::oneapi::bfloat16() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return __spirv_VectorExtractDynamic(M.spvm, idx); | ||
return __spirv_VectorExtractDynamic< | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, | ||
NumCols, spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>(M.spvm, idx); | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
|
@@ -182,8 +228,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
|
||
explicit operator bool() { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic( | ||
M.spvm, idx))) >= std::numeric_limits<float>::epsilon(); | ||
return std::fabs(static_cast<float>( | ||
__spirv_VectorExtractDynamic< | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, | ||
NumRows, NumCols, spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>(M.spvm, idx))) >= | ||
std::numeric_limits<float>::epsilon(); | ||
#else | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
|
@@ -205,7 +256,14 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
NumCols, Use, Layout, Group> &rhs) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
M.spvm = __spirv_VectorInsertDynamic( | ||
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); | ||
M.spvm, | ||
__spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16, | ||
sycl::ext::oneapi::bfloat16, NumRows, | ||
NumCols, spv_matrix_use_traits<Use>::value, | ||
spv_matrix_layout_traits<Layout>::value, | ||
spv_scope_traits<Group>::value>(rhs.M.spvm, | ||
rhs.idx), | ||
idx); | ||
return *this; | ||
#else | ||
(void)rhs; | ||
|
@@ -218,7 +276,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
#define OP(opassign, op) \ | ||
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \ | ||
M.spvm = __spirv_VectorInsertDynamic( \ | ||
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \ | ||
M.spvm, \ | ||
__spirv_VectorExtractDynamic< \ | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \ | ||
NumCols, spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \ | ||
idx); \ | ||
return *this; \ | ||
} | ||
#else // __SYCL_DEVICE_ONLY__ | ||
|
@@ -241,13 +305,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \ | ||
Layout, Group> &lhs, \ | ||
const sycl::ext::oneapi::bfloat16 &rhs) { \ | ||
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ | ||
return __spirv_VectorExtractDynamic< \ | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \ | ||
NumCols, spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \ | ||
} \ | ||
friend type operator op( \ | ||
const sycl::ext::oneapi::bfloat16 &lhs, \ | ||
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \ | ||
Layout, Group> &rhs) { \ | ||
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ | ||
return __spirv_VectorExtractDynamic< \ | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \ | ||
NumCols, spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \ | ||
} | ||
OP(sycl::ext::oneapi::bfloat16, +) | ||
OP(sycl::ext::oneapi::bfloat16, -) | ||
|
@@ -259,15 +331,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, | |
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \ | ||
Layout, Group> &lhs, \ | ||
const sycl::ext::oneapi::bfloat16 &rhs) { \ | ||
return type{static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \ | ||
return type{static_cast<float>( \ | ||
__spirv_VectorExtractDynamic< \ | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \ | ||
NumCols, spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \ | ||
op static_cast<float>(rhs)}; \ | ||
} \ | ||
friend type operator op( \ | ||
const sycl::ext::oneapi::bfloat16 &lhs, \ | ||
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \ | ||
Layout, Group> &rhs) { \ | ||
return type{static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \ | ||
return type{static_cast<float>( \ | ||
__spirv_VectorExtractDynamic< \ | ||
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \ | ||
NumCols, spv_matrix_use_traits<Use>::value, \ | ||
spv_matrix_layout_traits<Layout>::value, \ | ||
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \ | ||
op static_cast<float>(lhs)}; \ | ||
} | ||
OP(bool, ==) | ||
OP(bool, !=) | ||
|
@@ -358,7 +440,7 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< | |
// End wi_data definition | ||
|
||
template < | ||
typename Group, typename T, | ||
typename Group, typename T, typename Tp, | ||
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows, | ||
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout, | ||
access::address_space Space, access::decorated IsDecorated, | ||
|
@@ -368,7 +450,7 @@ template < | |
inline __SYCL_ALWAYS_INLINE void | ||
joint_matrix_store(Group sg, | ||
sycl::ext::oneapi::experimental::matrix::joint_matrix< | ||
Group, T, Use, NumRows, NumCols, Layout> &src, | ||
Group, Tp, Use, NumRows, NumCols, Layout> &src, | ||
multi_ptr<T, Space, IsDecorated> dst, size_t stride) { | ||
#if defined(__SYCL_DEVICE_ONLY__) | ||
#if defined(__NVPTX__) | ||
|
@@ -383,7 +465,7 @@ joint_matrix_store(Group sg, | |
#else | ||
// intel's impl | ||
T *Ptr = dst.get(); | ||
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols, | ||
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols, | ||
sycl::ext::oneapi::experimental::matrix:: | ||
spv_matrix_use_traits<Use>::value, | ||
sycl::ext::oneapi::experimental::matrix:: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.