Skip to content

Commit 846de1f

Browse files
committed
Fix bug in approx::exp(bfloat16) for HIP
1 parent a2b08a5 commit 846de1f

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

include/kernel_float/approx.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace kernel_float {
1010
namespace approx {
1111

1212
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
13+
static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short");
1314
using uint32_t = unsigned int;
15+
using uint16_t = unsigned short;
1416

1517
template<typename T, typename U>
1618
KERNEL_FLOAT_DEVICE T transmute(const U& input) {
@@ -353,12 +355,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
353355
static constexpr float OFFSET = 382.4958400542335;
354356
static constexpr float MINIMUM = 382;
355357

356-
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
357-
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
358+
float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
359+
float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
358360

359361
return {
360-
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
361-
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(b)))};
362+
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(a))),
363+
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(b)))};
362364
}
363365
#endif
364366
} // namespace approx

single_include/kernel_float.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-11-26 14:20:49.081641
20-
// git hash: 76c695a4cc5b13b3d5841ac5085574a5b47a299c
19+
// date: 2024-12-02 10:59:19.296684
20+
// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -4535,7 +4535,9 @@ namespace kernel_float {
45354535
namespace approx {
45364536

45374537
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
4538+
static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short");
45384539
using uint32_t = unsigned int;
4540+
using uint16_t = unsigned short;
45394541

45404542
template<typename T, typename U>
45414543
KERNEL_FLOAT_DEVICE T transmute(const U& input) {
@@ -4878,12 +4880,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
48784880
static constexpr float OFFSET = 382.4958400542335;
48794881
static constexpr float MINIMUM = 382;
48804882

4881-
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
4882-
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
4883+
float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
4884+
float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
48834885

48844886
return {
4885-
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
4886-
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(b)))};
4887+
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(a))),
4888+
transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(b)))};
48874889
}
48884890
#endif
48894891
} // namespace approx

0 commit comments

Comments
 (0)