diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6d69f9e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,120 @@
+# Editor temporary/working/backup files #
+#########################################
+.#*
+[#]*#
+*~
+*$
+*.bak
+*.diff
+.idea/
+*.iml
+*.ipr
+*.iws
+*.org
+.project
+pmip
+*.rej
+.settings/
+.*.sw[nop]
+.sw[nop]
+*.tmp
+*.vim
+.vscode
+tags
+cscope.out
+# gnu global
+GPATH
+GRTAGS
+GSYMS
+GTAGS
+.cache
+.mypy_cache/
+
+# Compiled source #
+###################
+*.a
+*.com
+*.class
+*.dll
+*.exe
+*.o
+*.o.d
+*.py[ocd]
+*.so
+*.mod
+
+# Packages #
+############
+# it's better to unpack these files and commit the raw source
+# git has its own built in compression methods
+*.7z
+*.bz2
+*.bzip2
+*.dmg
+*.gz
+*.iso
+*.jar
+*.rar
+*.tar
+*.tbz2
+*.tgz
+*.zip
+
+# Python files #
+################
+# meson build/installation directories
+build
+build-install
+# meson python output
+.mesonpy-native-file.ini
+# sphinx build directory
+_build
+# dist directory is where sdist/wheel end up
+dist
+doc/build
+doc/docenv
+doc/cdoc/build
+# Egg metadata
+*.egg-info
+# The shelf plugin uses this dir
+./.shelf
+.cache
+pip-wheel-metadata
+.python-version
+# virtual envs
+numpy-dev/
+venv/
+
+# Paver generated files #
+#########################
+/release
+
+# Logs and databases #
+######################
+*.log
+*.sql
+*.sqlite
+
+# Patches #
+###########
+*.patch
+*.diff
+
+# Do not ignore the following patches: #
+########################################
+!tools/ci/emscripten/0001-do-not-set-meson-environment-variable-pyodide-gh-4502.patch
+
+# OS generated files #
+######################
+.DS_Store*
+.VolumeIcon.icns
+.fseventsd
+Icon?
+.gdb_history
+ehthumbs.db
+Thumbs.db
+.directory
+
+# pytest generated files #
+##########################
+/.pytest_cache
diff --git a/.spin/cmds.py b/.spin/cmds.py
new file mode 100644
index 0000000..6376c2a
--- /dev/null
+++ b/.spin/cmds.py
@@ -0,0 +1,18 @@
+import os
+import pathlib
+import sys
+import click
+import spin
+from spin.cmds import meson
+
+curdir = pathlib.Path(__file__).parent
+rootdir = curdir.parent
+
+
+@click.command(help="Generate Sollya-based files")
+@click.option("-f", "--force", is_flag=True, help="Force regenerate all files")
+def generate(*, force):
+    spin.util.run(
+        ["python", str(rootdir / "tools" / "sollya" / "generate.py")]
+        + (["--force"] if force else []),
+    )
diff --git a/npsr/common.h b/npsr/common.h
new file mode 100644
index 0000000..9674212
--- /dev/null
+++ b/npsr/common.h
@@ -0,0 +1,37 @@
+#ifndef NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_
+#define NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_
+
+#include <cfenv>
+
+#include <hwy/base.h>
+#include <hwy/highway.h>
+
+#include "precise.h"
+
+#endif  // NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_
+
+#if defined(NUMPY_SIMD_ROUTINES_NPSR_COMMON_FOREACH_H_) == \
+    defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NUMPY_SIMD_ROUTINES_NPSR_COMMON_FOREACH_H_
+#undef NUMPY_SIMD_ROUTINES_NPSR_COMMON_FOREACH_H_
+#else
+#define NUMPY_SIMD_ROUTINES_NPSR_COMMON_FOREACH_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace npsr::HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+using hn::DFromV;
+using hn::MFromD;
+using hn::Rebind;
+using hn::RebindToUnsigned;
+using hn::TFromD;
+using hn::TFromV;
+using hn::VFromD;
+constexpr bool kNativeFMA = HWY_NATIVE_FMA != 0;
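+// kNativeFMA gates the Cody-Waite splits used throughout npsr/trig: targets
+// without fused multiply-add select the alternate constant tables (see
+// kPi/kPiDiv16Prec29 in npsr/trig/data/constants.h) whose pieces are narrow
+// enough that each n*pi_chunk product stays exactly representable.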
+
+HWY_ATTR void DummyToSuppressUnusedWarning() {}
+}  // namespace npsr::HWY_NAMESPACE
+HWY_AFTER_NAMESPACE();
+
+#endif  // NUMPY_SIMD_ROUTINES_NPSR_COMMON_FOREACH_H_
diff --git a/npsr/npsr.h b/npsr/npsr.h
new file mode 100644
index 0000000..9ca3af7
--- /dev/null
+++ b/npsr/npsr.h
@@ -0,0 +1,12 @@
+// These headers are included once per target, which is ensured by the
+// toggle check.
+// clang-format off
+#if defined(_NPSR_NPSR_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef _NPSR_NPSR_H_
+#undef _NPSR_NPSR_H_
+#else
+#define _NPSR_NPSR_H_
+#endif
+
+#include "npsr/trig/inl.h"
+
+#endif  // _NPSR_NPSR_H_
diff --git a/npsr/precise.h b/npsr/precise.h
new file mode 100644
index 0000000..e229f9f
--- /dev/null
+++ b/npsr/precise.h
@@ -0,0 +1,151 @@
+#ifndef NUMPY_SIMD_ROUTINES_NPSR_PRECISE_H_
+#define NUMPY_SIMD_ROUTINES_NPSR_PRECISE_H_
+#include <cfenv>
+#include <type_traits>
+
+namespace npsr {
+
+struct _NoLargeArgument {};
+struct _NoSpecialCases {};
+struct _NoExceptions {};
+struct _LowAccuracy {};
+constexpr auto kNoLargeArgument = _NoLargeArgument{};
+constexpr auto kNoSpecialCases = _NoSpecialCases{};
+constexpr auto kNoExceptions = _NoExceptions{};
+constexpr auto kLowAccuracy = _LowAccuracy{};
+
+struct Round {
+  struct _Force {};
+  static constexpr auto kForce = _Force{};
+};
+
+struct Subnormal {
+  struct _DAZ {};
+  struct _FTZ {};
+  struct _IEEE754 {};
+  static constexpr auto kDAZ = _DAZ{};
+  static constexpr auto kFTZ = _FTZ{};
+  static constexpr auto kIEEE754 = _IEEE754{};
+};
+
+struct FPExceptions {
+  static constexpr auto kNone = 0;
+  static constexpr auto kInvalid = FE_INVALID;
+  static constexpr auto kDivByZero = FE_DIVBYZERO;
+  static constexpr auto kOverflow = FE_OVERFLOW;
+  static constexpr auto kUnderflow = FE_UNDERFLOW;
+};
+
+/**
+ * @brief RAII floating-point precision control class
+ *
+ * The Precise class provides automatic management of floating-point
+ * environment settings during its lifetime. It uses RAII principles to save
+ * the current floating-point state on construction and restore it on
+ * destruction.
+ *
+ * The class is configured using variadic template arguments that specify
+ * the desired floating-point behavior through tag types.
+ *
+ * **IMPORTANT PERFORMANCE NOTE**: Create the Precise object BEFORE loops,
+ * not inside them. The constructor and destructor have overhead from saving
+ * and restoring floating-point state, so it should be done once per
+ * computational scope, not per iteration.
+ *
+ * @tparam Args Variadic template arguments for configuration flags
+ *
+ * @example
+ * ```cpp
+ * using namespace hwy::HWY_NAMESPACE;
+ * using namespace npsr;
+ * using namespace npsr::HWY_NAMESPACE;
+ *
+ * Precise precise = {kLowAccuracy, kNoSpecialCases, kNoLargeArgument};
+ * const ScalableTag<float> d;
+ * using V = Vec<ScalableTag<float>>;
+ * for (size_t i = 0; i < n; i += Lanes(d)) {
+ *   V v = LoadU(d, &input[i]);
+ *   V result = Sin(precise, v);
+ *   StoreU(result, d, &output[i]);
+ * }
+ * ```
+ */
+template <typename... Args>
+class Precise {
+ public:
+  Precise() {
+    if constexpr (!kNoExceptions) {
+      fegetexceptflag(&_exceptions, FE_ALL_EXCEPT);
+    }
+    if constexpr (kRoundForce) {
+      _rounding_mode = fegetround();
+      int new_mode = _NewRoundingMode();
+      if (_rounding_mode != new_mode) {
+        _retrieve_rounding_mode = true;
+        fesetround(new_mode);
+      }
+    }
+  }
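+  // Compile-time effect of the tags (illustrative): each tag passed to the
+  // constructor flips the matching static flag, e.g.
+  //   Precise p{kNoExceptions, Round::kForce};
+  //   static_assert(decltype(p)::kNoExceptions && decltype(p)::kRoundForce);
+  // The tag objects carry no state; they only select the Args pack.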
+  template <typename T1, typename... Rest>
+  Precise(T1&& arg1, Rest&&... rest) : Precise() {}
+
+  void FlushExceptions() { fesetexceptflag(&_exceptions, FE_ALL_EXCEPT); }
+
+  void Raise(int errors) {
+    static_assert(!kNoExceptions,
+                  "Cannot raise exceptions in NoExceptions mode");
+    _exceptions |= errors;
+  }
+  ~Precise() {
+    if constexpr (!kNoExceptions) {
+      FlushExceptions();
+    }
+    if constexpr (kRoundForce) {
+      if (_retrieve_rounding_mode) {
+        fesetround(_rounding_mode);
+      }
+    }
+  }
+  static constexpr bool kNoExceptions =
+      (std::is_same_v<_NoExceptions, Args> || ...);
+  static constexpr bool kNoLargeArgument =
+      (std::is_same_v<_NoLargeArgument, Args> || ...);
+  static constexpr bool kNoSpecialCases =
+      (std::is_same_v<_NoSpecialCases, Args> || ...);
+  static constexpr bool kLowAccuracy =
+      (std::is_same_v<_LowAccuracy, Args> || ...);
+  // defaults to high accuracy unless the low accuracy flag is set
+  static constexpr bool kHighAccuracy = !kLowAccuracy;
+  // defaults to large argument support unless kNoLargeArgument is set
+  static constexpr bool kLargeArgument = !kNoLargeArgument;
+  // defaults to special cases support unless kNoSpecialCases is set
+  static constexpr bool kSpecialCases = !kNoSpecialCases;
+  // defaults to exception support unless kNoExceptions is set
+  static constexpr bool kExceptions = !kNoExceptions;
+
+  static constexpr bool kRoundForce =
+      (std::is_same_v<Round::_Force, Args> || ...);
+
+  static constexpr bool kDAZ = (std::is_same_v<Subnormal::_DAZ, Args> || ...);
+  static constexpr bool kFTZ = (std::is_same_v<Subnormal::_FTZ, Args> || ...);
+  static constexpr bool _kIEEE754 =
+      (std::is_same_v<Subnormal::_IEEE754, Args> || ...);
+  static_assert(!_kIEEE754 || !(kDAZ || kFTZ),
+                "IEEE754 mode cannot be used "
+                "with Denormals Are Zero (DAZ) or Flush To Zero (FTZ) "
+                "subnormal handling");
+  static constexpr bool kIEEE754 = _kIEEE754 || !(kDAZ || kFTZ);
+
+ private:
+  int _NewRoundingMode() const { return FE_TONEAREST; }
+  int _rounding_mode = 0;
+  bool _retrieve_rounding_mode = false;
+  fexcept_t _exceptions;
+};
+
+Precise() -> Precise<>;
+
+// For Precise{args...} -> Precise<decayed arg types...>
+template <typename T1, typename... Rest>
+Precise(T1&&, Rest&&...) -> Precise<std::decay_t<T1>, std::decay_t<Rest>...>;
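+// Deduction example (illustrative): Precise{kLowAccuracy, Subnormal::kDAZ}
+// deduces Precise<_LowAccuracy, Subnormal::_DAZ>; std::decay_t strips the
+// references so lvalue and rvalue tags pick the same specialization.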
+
+}  // namespace npsr
+#endif  // NUMPY_SIMD_ROUTINES_NPSR_PRECISE_H_
diff --git a/npsr/trig/data/approx.h.sol b/npsr/trig/data/approx.h.sol
new file mode 100644
index 0000000..86693e1
--- /dev/null
+++ b/npsr/trig/data/approx.h.sol
@@ -0,0 +1,52 @@
+suppressmessage(186, 185, 184);
+
+procedure ApproxLut4_(pT, pFunc, pFuncDriv) {
+  var r, i, $;
+
+  $.num_lut = match pT.kSize
+    with 64: (2^9)
+    default: (2^8);
+
+  $.low_round = match pT.kSize
+    with 64: ([|24, RZ|])
+    default: ([|pT.kDigits, RN|]);
+  $.scale = 2.0 * pi / $.num_lut;
+
+  r = [||];
+  for i from 0 to $.num_lut - 1 do {
+    $.angle = i * $.scale;
+    $.exact = pFunc($.angle);
+    $.high = pT.kRound($.exact);
+    $.low = pT.kRound(round($.exact - $.high, $.low_round[0], $.low_round[1]));
+
+    $.deriv_exact = pFuncDriv($.angle);
+    $.k = ceil(log2(abs($.deriv_exact)));
+    if ($.deriv_exact < 0) then $.k = -$.k;
+
+    $.sigma = 2.0^$.k;
+    $.deriv = pT.kRound($.deriv_exact - $.sigma);
+    r = r @ [|$.deriv, $.sigma, $.high, $.low|];
+  };
+  return ToStringCArray(r, pT.kCSFX, 4);
+};
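+
+// Worked entry (illustrative), i = 0 of the sine table: angle = 0 gives
+// high = low = 0; the derivative cos(0) = 1 gives k = 0, sigma = 2^0 = 1 and
+// deriv = 1 - 1 = 0, i.e. the 4-tuple {0, 1, 0, 0}. The consumer rebuilds
+// f'(angle) as deriv + sigma (see PHASE 9/10 in npsr/trig/extended-inl.h).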
+
+Append(
+  "template <typename T> constexpr char kSinApproxTable[] = {};",
+  "template <> constexpr float kSinApproxTable<float>[] = ",
+  ApproxLut4_(Float32, sin(x), cos(x)),
+  "",
+  "template <> constexpr double kSinApproxTable<double>[] = ",
+  ApproxLut4_(Float64, sin(x), cos(x)),
+  ""
+);
+Append(
+  "template <typename T> constexpr char kCosApproxTable[] = {};",
+  "template <> constexpr float kCosApproxTable<float>[] = ",
+  ApproxLut4_(Float32, cos(x), -sin(x)),
+  "",
+  "template <> constexpr double kCosApproxTable<double>[] = ",
+  ApproxLut4_(Float64, cos(x), -sin(x)),
+  ""
+);
+
+WriteCPPHeader("npsr::trig::data");
diff --git a/npsr/trig/data/constants.h.sol b/npsr/trig/data/constants.h.sol
new file mode 100644
index 0000000..da78375
--- /dev/null
+++ b/npsr/trig/data/constants.h.sol
@@ -0,0 +1,87 @@
+procedure ConstantsToArrayF32_(pArgs = ...) {
+  return ToStringCArray(ConstantsFromArray(pArgs), "f", 4);
+};
+procedure ConstantsToArrayF64_(pArgs = ...) {
+  return ToStringCArray(ConstantsFromArray(pArgs), "", 4);
+};
+
+Append(
+  "template <typename T, bool FMA = true> constexpr char kPi[] = {};",
+
+  "template <> constexpr float kPi<float>[] = " @
+    ConstantsToArrayF32_(pi, [|RN, 24, 24, 24|]),
+  "template <> constexpr float kPi<float, false>[] = " @
+    ConstantsToArrayF32_(pi, [|RD, 11, 11, 11|], [|RN, 24|]), // no FMA
+
+
+  "template <> constexpr double kPi<double>[] = " @
+    ConstantsToArrayF64_(pi, [|RN, 53|], [|RD, 53|], [|RU, 53|]),
+  "template <> constexpr double kPi<double, false>[] = " @
+    ConstantsToArrayF64_(pi, [|RN, 24, 24, 24|], [|RN, 53|]), // no FMA
+
+  ""
+);
+
+Append(
+  "template <bool FMA = true> constexpr double kPiPrec35[] = " @
+    ConstantsToArrayF64_(pi, [|RN, 35|], [|RD, 53|]),
+  "template <> constexpr double kPiPrec35<false>[] = " @
+    ConstantsToArrayF64_(pi, [|RN, 24, 24, 24|]),
+  ""
+);
+
+Append(
+  "template <typename T> constexpr char kPiMul2[] = {};",
+
+  "template <> constexpr float kPiMul2<float>[] = " @
+    ConstantsToArrayF32_(pi*2, [|RN, 24, 24|]),
+  "template <> constexpr double kPiMul2<double>[] = " @
+    ConstantsToArrayF64_(pi*2, [|RN, 53, 53|]),
+  ""
+);
+
+vNFma = Constants(pi/16, [|RN, 27, 27|], [|RN, 29|], [|RN, 53|]);
+Append(
+  "template <bool FMA = true> constexpr double kPiDiv16Prec29[] = " @
+    ConstantsToArrayF64_(pi/16, [|RN, 53|], [|RN, 29|], [|RN, 53|]),
+  "template <> constexpr double kPiDiv16Prec29<false>[] = " @
+    ToStringCArray([|vNFma[0], vNFma[2], vNFma[3], vNFma[1]|], "", 4),
+  ""
+);
+
+Append(
+  "template <typename T> constexpr char kInvPi = '_';",
+  "template <> constexpr float kInvPi<float> = " @
+    single(1/pi) @ "f;",
+
+  "template <> constexpr double kInvPi<double> = " @
+    double(1/pi) @ ";",
+  ""
+);
+
+Append(
+  "template <typename T> constexpr char kHalfPi = '_';",
+
+  "template <> constexpr float kHalfPi<float> = " @
+    single(pi/2) @ "f;",
+
+  "template <> constexpr double kHalfPi<double> = " @
+    double(pi/2) @ ";",
+  ""
+);
+
+Append(
+  "template <typename T> constexpr char k16DivPi = '_';",
+
+  "template <> constexpr float k16DivPi<float> = " @
+    single(16/pi) @ "f;",
+
+  "template <> constexpr double k16DivPi<double> = " @
+    double(16/pi) @ ";",
+  ""
+);
+
+// Dump();
+
+WriteCPPHeader("npsr::trig::data");
+
diff --git a/npsr/trig/data/data.h.sol b/npsr/trig/data/data.h.sol
new file mode 100644
index 0000000..3c393fc
--- /dev/null
+++ b/npsr/trig/data/data.h.sol
@@ -0,0 +1,10 @@
+var header;
+for header in [|"constants", "high", "approx", "reduction"|] do {
+  Append(
+    "#include \"npsr/trig/data/" @ header @ ".h\""
+  );
+};
+
+WriteCPPHeader();
+
+
diff --git a/npsr/trig/data/high.h.sol b/npsr/trig/data/high.h.sol
new file mode 100644
index 0000000..d29151a
--- /dev/null
+++ b/npsr/trig/data/high.h.sol
@@ -0,0 +1,54 @@
+procedure PiDivTable_(pT, pFunc, pBy) {
+  var r, i, pi_by;
+  pi_by = pi / pBy;
+  r = [||];
+  for i from 0 to pBy - 1 do {
+    r = r :. pT.kRound(pFunc(i * pi_by));
+  };
+  return ToStringCArray(r, pT.kCSFX, 1);
+};
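+
+// Packing sketch for the next procedure (F64, per the mask comment inside):
+// each slot keeps cos_lo's upper 32 bits in place and parks sin_lo's upper
+// 32 bits in the lower half, so high-inl.h can use the slot directly as
+// cos_lo and recover sin_lo with a ShiftLeft<32>. Only the top ~20 mantissa
+// bits of each low word survive, which the header comments call negligible.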
+procedure PiDivPackLowTable_(pT, pFunc0, pFunc1, pBy) {
+  var r, i, digits, $;
+  $.pi_by = pi / pBy;
+  r = [||];
+  for i from 0 to pBy - 1 do {
+    $.hi0 = pT.kRound(pFunc0(i * $.pi_by));
+    $.hi1 = pT.kRound(pFunc1(i * $.pi_by));
+    $.hi0_low = pT.kRound(pFunc0(i * $.pi_by) - $.hi0);
+    $.hi1_low = pT.kRound(pFunc1(i * $.pi_by) - $.hi1);
+    r = r @ [|$.hi0_low, $.hi1_low|];
+  };
+  digits = ToDigits(pT, r);
+  $.half_size = pT.kSize / 2;
+  $.lower_bits = 2^$.half_size;
+  r = [||];
+  for i from 0 to length(digits) - 1 by 2 do {
+    $.hi0 = digits[i];
+    $.hi1 = digits[i + 1];
+    // F64: (hi1 & 0xFFFFFFFF00000000) | ((hi0 >> 32) & 0xFFFFFFFF)
+    $.pack = mod(RightShift($.hi0, $.half_size), $.lower_bits);
+    $.pack = $.pack + $.hi1 - mod($.hi1, $.lower_bits);
+    r = r :. $.pack;
+  };
+  r = FromDigits(pT, r);
+  return ToStringCArray(r, "", 4);
+};
+
+Append(
+  "constexpr double kHiSinKPi16Table[] = " @
+    PiDivTable_(Float64, sin(x), 16),
+  "",
+  "constexpr double kHiCosKPi16Table[] = " @
+    PiDivTable_(Float64, cos(x), 16),
+  ""
+);
+
+Append(
+  "constexpr double kPackedLowSinCosKPi16Table[] = " @
+    PiDivPackLowTable_(Float64, sin(x), cos(x), 16),
+  ""
+);
+
+// Dump();
+WriteCPPHeader("npsr::trig::data");
diff --git a/npsr/trig/data/reduction.h.sol b/npsr/trig/data/reduction.h.sol
new file mode 100644
index 0000000..9a1519a
--- /dev/null
+++ b/npsr/trig/data/reduction.h.sol
@@ -0,0 +1,47 @@
+// Precompute int(2^exp × 4/π) with ~96-bit precision (f32) or ~192-bit
+// precision (f64) and split them into three chunks: 32-bit chunks for
+// single precision, 64-bit chunks for double precision.
+//
+// This generates a lookup table for large range reduction in trigonometric
+// functions. The table is used to compute mantissa × (2^exp × 4/π) using
+// wider integer multiplications for precision:
+//   - f32: 16×16 → 32-bit multiplications
+//   - f64: 32×32 → 64-bit multiplications
+//
+// For input x = mantissa × 2^exp, the algorithm becomes:
+//   x × 4/π = mantissa × table_lookup[exp],
+// providing high precision without floating-point errors.
+//
+// Args:
+//   pT: the Float32 or Float64 record from core.sol
+//   pOffset: exponent offset folded into the table (70 for f32, 137 for f64)
+procedure ReductionTable_(pT, pOffset) {
+  var r, i, j, $;
+  SetDisplay(decimal);
+  SetPrec(pT.kDigits * 3);
+  $.mask = 2^pT.kSize;
+  $.scalar = 4 / pi;
+  r = [||];
+  for i from 0 to pT.kMaxExpBiased + 1 do {
+    $.exp_shift = i - pT.kBias + pOffset;
+    $._int = LeftShift($.scalar, $.exp_shift);
+    $.chunks = [||];
+    for j in [|pT.kSize * 2, pT.kSize, 0|] do {
+      $.rshift = RightShift($._int, j);
+      $.apply_mask = mod($.rshift, $.mask);
+      $.chunks = $.chunks @ [|$.apply_mask|];
+    };
+    r = r @ $.chunks;
+  };
+  r = ToStringCArray(r, pT.kCUintSFX, 3);
+  RestorePrec();
+  RestoreDisplay();
+  return r;
+};
+
+Append(
+  "template <typename T> constexpr T kLargeReductionTable[] = {};",
+  "template <> constexpr uint32_t kLargeReductionTable<uint32_t>[] = " @
+    ReductionTable_(Float32, 70) @ ";",
+  "",
+  "template <> constexpr uint64_t kLargeReductionTable<uint64_t>[] = " @
+    ReductionTable_(Float64, 137) @ ";",
+  ""
+);
+WriteCPPHeader("npsr::trig::data");
diff --git a/npsr/trig/extended-inl.h b/npsr/trig/extended-inl.h
new file mode 100644
index 0000000..f84c791
--- /dev/null
+++ b/npsr/trig/extended-inl.h
@@ -0,0 +1,312 @@
+#include "hwy/base.h"
+#if defined(NPSR_TRIG_EXTENDED_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NPSR_TRIG_EXTENDED_INL_H_
+#undef NPSR_TRIG_EXTENDED_INL_H_
+#else
+#define NPSR_TRIG_EXTENDED_INL_H_
+#endif
+
+#include "npsr/common.h"
+#include "npsr/trig/data/data.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace npsr::HWY_NAMESPACE::trig {
+
+template <bool IS_COS, typename V>
+HWY_API V Extended(V x) {
+  using namespace hn;
+  namespace data = ::npsr::trig::data;
+  using hwy::ExponentBits;
+  using hwy::MantissaBits;
+  using hwy::MantissaMask;
+  using hwy::SignMask;
+
+  using D = DFromV<V>;
+  using DI = RebindToSigned<D>;
+  using DU = RebindToUnsigned<D>;
+  using VI = Vec<DI>;
+  using VU = Vec<DU>;
+  using T = TFromV<V>;
+  using TU = TFromV<VU>;
+  const D d;
+  const DI di;
+  const DU du;
+
+  constexpr bool kIsSingle = std::is_same_v<T, float>;
+
+  // =============================================================================
+  // PHASE 1: Table Lookup for Reduction Constants
+  // =============================================================================
+  // Each table entry contains 3 consecutive values [high, mid, low]
+  // providing ~96-bit (F32) or ~192-bit (F64) precision for (4/π) × 2^exp
+  VU u_exponent = GetBiasedExponent(x);
+  VI i_table_idx =
+      BitCast(di, Add(ShiftLeft<1>(u_exponent), u_exponent));  // e*2 + e = e*3
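+  // Lookup sketch (illustrative): entry e holds, as three consecutive
+  // kSize-bit chunks, the bits of floor(2^(e - bias + offset) * 4/π), so the
+  // mantissa-times-chunk products below rebuild x * 4/π in fixed point with
+  // no floating-point rounding; this is the Payne-Hanek idea with the
+  // per-exponent shift baked into the table.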
+
+  // Gather three parts of (4/π) × 2^exp from precomputed table
+  // Generated by the Sollya script with offset: 70 (F32) or 137 (F64)
+  VU u_p_hi = GatherIndex(du, data::kLargeReductionTable<TU>, i_table_idx);
+  VU u_p_med = GatherIndex(du, data::kLargeReductionTable<TU> + 1, i_table_idx);
+  VU u_p_lo = GatherIndex(du, data::kLargeReductionTable<TU> + 2, i_table_idx);
+
+  // =============================================================================
+  // PHASE 2: Extract and Normalize Mantissa
+  // =============================================================================
+
+  // Extract and Normalize Mantissa
+  V abx = Abs(x);
+  VU u_input = BitCast(du, abx);
+  VU u_significand = And(u_input, Set(du, MantissaMask<T>()));
+  // Add implicit leading 1 bit
+  VU u_integer_bit =
+      Or(u_significand, Set(du, static_cast<TU>(1) << MantissaBits<T>()));
+  VU u_mantissa = Or(u_significand, u_integer_bit);
+
+  // Split mantissa into halves for extended precision multiplication
+  // F32: 16-bit halves, F64: 32-bit halves
+  constexpr int kHalfShift = (sizeof(T) / 2) * 8;
+  VU u_low_mask = Set(du, (static_cast<TU>(1) << kHalfShift) - 1);
+  VU u_m0 = And(u_mantissa, u_low_mask);
+  VU u_m1 = ShiftRight<kHalfShift>(u_mantissa);
+
+  // Split reduction constants into halves
+  VU u_p0 = And(u_p_lo, u_low_mask);
+  VU u_p1 = ShiftRight<kHalfShift>(u_p_lo);
+  VU u_p2 = And(u_p_med, u_low_mask);
+  VU u_p3 = ShiftRight<kHalfShift>(u_p_med);
+  VU u_p4 = And(u_p_hi, u_low_mask);
+  VU u_p5 = ShiftRight<kHalfShift>(u_p_hi);
+
+  // =============================================================================
+  // PHASE 3: Extended Precision Multiplication
+  // =============================================================================
+  // mantissa × (4/π × 2^exp) using half-word multiplications
+  // F32: 16×16→32 bit, F64: 32×32→64 bit multiplications
+
+  // Products with highest precision part
+  VU u_m04 = Mul(u_m0, u_p4);
+  VU u_m05 = Mul(u_m0, u_p5);
+  VU u_m14 = Mul(u_m1, u_p4);
+  // Omit u_m1 × u_p5 to prevent overflow
+
+  // Products with medium precision part
+  VU u_m02 = Mul(u_m0, u_p2);
+  VU u_m03 = Mul(u_m0, u_p3);
+  VU u_m12 = Mul(u_m1, u_p2);
+  VU u_m13 = Mul(u_m1, u_p3);
+
+  // Products with lowest precision part
+  VU u_m01 = Mul(u_m0, u_p1);
+  VU u_m10 = Mul(u_m1, u_p0);
+  VU u_m11 = Mul(u_m1, u_p1);
+
+  // =============================================================================
+  // PHASE 4: Carry Propagation and Result Assembly
+  // =============================================================================
+  // Extract carry bits from each product
+  VU u_carry04 = ShiftRight<kHalfShift>(u_m04);
+  VU u_carry02 = ShiftRight<kHalfShift>(u_m02);
+  VU u_carry03 = ShiftRight<kHalfShift>(u_m03);
+  VU u_carry01 = ShiftRight<kHalfShift>(u_m01);
+  VU u_carry10 = ShiftRight<kHalfShift>(u_m10);
+
+  // Extract lower halves
+  VU u_low04 = And(u_m04, u_low_mask);
+  VU u_low02 = And(u_m02, u_low_mask);
+  VU u_low05 = And(u_m05, u_low_mask);
+  VU u_low03 = And(u_m03, u_low_mask);
+
+  // Column-wise accumulation (Intel SVML pattern)
+  VU u_col3 = Add(u_low05, Add(u_m14, u_carry04));
+  VU u_col2 = Add(u_low04, Add(u_m13, u_carry03));
+  VU u_col1 = Add(u_low02, Add(u_m11, u_carry01));
+  VU u_col0 = Add(u_low03, Add(u_m12, u_carry02));
+
+  // Carry propagation through columns
+  VU u_sum0 = Add(u_carry10, u_col1);
+  VU u_carry_final0 = ShiftRight<kHalfShift>(u_sum0);
+  VU u_sum1 = Add(u_carry_final0, u_col0);
+  VU u_carry_final1 = ShiftRight<kHalfShift>(u_sum1);
+  VU u_sum1_shifted = ShiftLeft<kHalfShift>(u_sum1);
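+  // The blocks above implement schoolbook long multiplication in base
+  // 2^kHalfShift (illustrative reading): each u_mXY is one digit product,
+  // the u_colN sums collect digits of equal weight, and the u_carry_final*
+  // terms ripple the carries upward. Only the top two result words are kept;
+  // u_m1 * u_p5 is omitted because its bits lie above that window.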
+  VU u_sum2 = Add(u_carry_final1, u_col2);
+  VU u_carry_final2 = ShiftRight<kHalfShift>(u_sum2);
+  VU u_sum3 = Add(u_carry_final2, u_col3);
+
+  // Assemble final result
+  VU u_result0 = And(u_sum0, u_low_mask);
+  VU u_result2 = And(u_sum2, u_low_mask);
+  VU u_result3 = ShiftLeft<kHalfShift>(u_sum3);
+
+  VU u_n_hi = Add(u_result3, u_result2);
+  VU u_n_lo = Add(u_sum1_shifted, u_result0);
+
+  // =============================================================================
+  // PHASE 5: Extract Quotient and Fractional Parts
+  // =============================================================================
+
+  // Extract integer quotient
+  constexpr int kQuotientShift =
+      ExponentBits<T>() + 1;  // 9 for F32, 12 for F64
+  VU u_shifted_n = ShiftRight<kQuotientShift>(u_n_hi);
+
+  // fractional shifts derived from magic constants
+  // F32: 5, 18, 14 (sum = 37, total with quotient = 46 = 2×23)
+  // F64: 28, 24, 40 (sum = 92, total with quotient = 104 = 2×52)
+  constexpr int kFracLowShift = kIsSingle ? 5 : 28;
+  constexpr int kFracMidShift = kIsSingle ? 18 : 24;
+  constexpr int kFracHighShift = kIsSingle ? 14 : 40;
+
+  // Verify total shift constraint
+  constexpr int kTotalShift =
+      kQuotientShift + kFracLowShift + kFracMidShift + kFracHighShift;
+  static_assert(kTotalShift == (kIsSingle ? 46 : 104),
+                "Total shift must equal 2×mantissa_bits");
+
+  // Extract fractional parts
+  constexpr TU kFracMidMask = (static_cast<TU>(1) << kFracMidShift) - 1;
+  VU u_frac_low_bits = And(u_n_lo, Set(du, kFracMidMask));
+  VU u_shifted_sig_lo = ShiftLeft<kFracLowShift>(u_frac_low_bits);
+  VU u_frac_mid_bits = ShiftRight<kFracMidShift>(u_n_lo);
+  constexpr TU kFracHighMask =
+      (static_cast<TU>(1) << (kFracHighShift - kFracLowShift)) - 1;
+  VU u_frac_high_bits = And(u_n_hi, Set(du, kFracHighMask));
+
+  // =============================================================================
+  // PHASE 6: Magic Number Conversion to Floating Point
+  // =============================================================================
+  // magic constants for branchless int→float conversion
+  // Handle sign bit
+  VU u_sign_bit = And(BitCast(du, x), Set(du, SignMask<T>()));
+  VU u_exponent_part =
+      Xor(u_sign_bit, BitCast(du, Set(d, static_cast<T>(1.0))));
+  VU u_quotient_signed = Or(u_shifted_n, u_exponent_part);
+
+  // Magic number conversion for quotient
+  V shifter = Set(d, kIsSingle ? 0x1.8p15f : 0x1.8p43);
+  V integer_part = Add(shifter, BitCast(d, u_quotient_signed));
+
+  V n_hi = Sub(integer_part, shifter);
+  n_hi = Sub(BitCast(d, u_quotient_signed), n_hi);
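+  // Magic-number trick (F32 flavor, illustrative): u_quotient_signed is 1.0f
+  // whose 23-bit mantissa holds u_n_hi >> 9; adding 0x1.8p15f (ulp 2^-8)
+  // rounds away all but the top 8 bits, so BitCast(integer_part) carries the
+  // rounded table index in its low byte while the two Subs recover the
+  // signed residual n_hi exactly; a round-to-nearest reduction without any
+  // explicit int/float conversion instructions.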
+
+  // constants for fractional parts
+  VU u_epsilon = BitCast(du, Set(d, kIsSingle ? 0x1p-23f : 0x1p-52));
+  VU u_exp_mid = Xor(u_sign_bit, u_epsilon);
+  VU u_shifted_sig_mid =
+      Or(ShiftLeft<kFracHighShift>(u_frac_high_bits), u_frac_mid_bits);
+  VU u_frac_mid_combined = Or(u_shifted_sig_mid, u_exp_mid);
+  V shifter_mid = BitCast(d, u_exp_mid);
+  V n_med = Sub(BitCast(d, u_frac_mid_combined), shifter_mid);
+
+  VU u_epsilon_low = BitCast(du, Set(d, kIsSingle ? 0x1p-46f : 0x1p-104));
+  VU u_exp_low = Xor(u_sign_bit, u_epsilon_low);
+  VU u_frac_low_combined = Or(u_shifted_sig_lo, u_exp_low);
+
+  V exp_low = BitCast(d, u_exp_low);
+  V frac_low_combined = BitCast(d, u_frac_low_combined);
+
+  V n = Add(n_hi, n_med);
+  V n_lo = Sub(n_hi, n);
+  n_lo = Add(n_med, n_lo);
+  n_lo = Add(n_lo, Sub(frac_low_combined, exp_low));
+
+  // =============================================================================
+  // PHASE 7: Convert to Radians
+  // =============================================================================
+
+  // Multiply by π with error compensation (Cody-Waite multiplication)
+  constexpr auto kPiMul2 = data::kPiMul2<T>;
+  const V pi2_hi = Set(d, kPiMul2[0]);
+  const V pi2_med = Set(d, kPiMul2[1]);
+
+  V r = Mul(pi2_hi, n);
+  V r_lo, r_w0, r_w1;
+  if constexpr (!kNativeFMA && kIsSingle) {
+    using DW = RepartitionToWide<D>;
+    using DH = Half<D>;
+    using VW = Vec<DW>;
+    const DW dw;
+    const DH dh;
+    VW pi2_whi = Set(dw, data::kPiMul2<double>[0]);
+    VW r0 = Mul(pi2_whi, PromoteUpperTo(dw, n));
+    VW r1 = Mul(pi2_whi, PromoteLowerTo(dw, n));
+    VW r_lo_w0 = Sub(r0, PromoteUpperTo(dw, r));
+    VW r_lo_w1 = Sub(r1, PromoteLowerTo(dw, r));
+    r_lo = Combine(d, DemoteTo(dh, r_lo_w0), DemoteTo(dh, r_lo_w1));
+    r_w0 = BitCast(d, r0);
+    r_w1 = BitCast(d, r1);
+  } else {
+    r_lo = MulSub(pi2_hi, n, r);
+    r_lo = MulAdd(pi2_med, n, r_lo);
+  }
+  r_lo = MulAdd(pi2_hi, n_lo, r_lo);
+
+  // =============================================================================
+  // PHASE 8: Small Argument Handling
+  // =============================================================================
+
+  const V min_input = Set(d, static_cast<T>(0x1p-20));
+  const auto is_small_arg = Gt(min_input, abx);
+
+  r = IfThenElse(is_small_arg, x, r);
+  r_lo = IfThenElse(is_small_arg, Zero(d), r_lo);
+  V r2 = Mul(r, r);
+
+  // =============================================================================
+  // PHASE 9: Table Lookup
+  // =============================================================================
+
+  const T *table_base =
+      IS_COS ? data::kCosApproxTable<T> : data::kSinApproxTable<T>;
+
+  // Calculate table index
+  VU u_n_mask = Set(du, kIsSingle ? 0xFF : 0x1FF);
+  VU u_index = And(BitCast(du, integer_part), u_n_mask);
+  VI u_table_index = BitCast(di, ShiftLeft<2>(u_index));
+
+  const V deriv_hi = GatherIndex(d, table_base, u_table_index);
+  const V sigma = GatherIndex(d, table_base + 1, u_table_index);
+  const V func_hi = GatherIndex(d, table_base + 2, u_table_index);
+  const V func_lo = GatherIndex(d, table_base + 3, u_table_index);
+  const V deriv = Add(deriv_hi, sigma);
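+  // Entry layout (see approx.h.sol): {f'(a)-sigma, sigma, f(a)_hi, f(a)_lo}
+  // at a = index * 2π/256 (F32) or 2π/512 (F64). sigma is a power of two, so
+  // Mul(sigma, r) below is exact, and PHASE 10 can assemble
+  // f(a) + f'(a)*r from error-free pieces before folding in the sin/cos
+  // series corrections.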
+
+  // =============================================================================
+  // PHASE 10: Final Assembly
+  // =============================================================================
+  V res_lo = NegMulAdd(func_hi, r, deriv);
+  res_lo = MulAdd(res_lo, r_lo, func_lo);
+  V res_hi_lo = MulAdd(sigma, r, func_hi);
+  V res_hi = MulAdd(deriv_hi, r, res_hi_lo);
+
+  V sum_cor = MulAdd(sigma, r, Sub(func_hi, res_hi_lo));
+  V deriv_hi_r_cor = MulAdd(deriv_hi, r, Sub(res_hi_lo, res_hi));
+  deriv_hi_r_cor = Add(deriv_hi_r_cor, sum_cor);
+  res_lo = Add(res_lo, deriv_hi_r_cor);
+
+  // Polynomial corrections
+  V s2 = Set(d, kIsSingle ? 0x1.1110b8p-7f : 0x1.1110fabb3551cp-7);
+  V s1 = Set(d, kIsSingle ? -0x1.555556p-3f : -0x1.5555555554448p-3);
+  V sin_poly = MulAdd(s2, r2, s1);
+  sin_poly = Mul(sin_poly, r);
+  sin_poly = Mul(sin_poly, r2);
+
+  V c1 = Set(d, kIsSingle ? 0x1.5554f8p-5f : 0x1.5555555554ccfp-5);
+  const V neg_half = Set(d, static_cast<T>(-0.5));
+  V cos_poly;
+  if constexpr (kIsSingle) {
+    cos_poly = MulAdd(c1, r2, neg_half);
+  } else {
+    V c2 = Set(d, -0x1.6c16ab163b2d7p-10);
+    cos_poly = MulAdd(c2, r2, c1);
+    cos_poly = MulAdd(cos_poly, r2, neg_half);
+  }
+  cos_poly = Mul(cos_poly, r2);
+
+  res_lo = MulAdd(sin_poly, deriv, res_lo);
+  res_lo = MulAdd(cos_poly, func_hi, res_lo);
+  return Add(res_hi, res_lo);
+}
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace npsr::HWY_NAMESPACE::trig
+HWY_AFTER_NAMESPACE();
+
+#endif  // NPSR_TRIG_EXTENDED_INL_H_
diff --git a/npsr/trig/high-inl.h b/npsr/trig/high-inl.h
new file mode 100644
index 0000000..271e197
--- /dev/null
+++ b/npsr/trig/high-inl.h
@@ -0,0 +1,301 @@
+#include "npsr/common.h"
+#include "npsr/trig/data/data.h"
+#include "npsr/utils-inl.h"
+
+#if defined(NPSR_TRIG_HIGH_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NPSR_TRIG_HIGH_INL_H_
+#undef NPSR_TRIG_HIGH_INL_H_
+#else
+#define NPSR_TRIG_HIGH_INL_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+
+namespace npsr::HWY_NAMESPACE::trig {
+
+template <bool IS_COS, typename V, HWY_IF_F32(TFromV<V>)>
+HWY_INLINE V High(V x) {
+  using namespace hn;
+  namespace data = ::npsr::trig::data;
+
+  using T = TFromV<V>;
+  using D = DFromV<V>;
+  using DU = RebindToUnsigned<D>;
+  using DH = Half<D>;
+  using DW = RepartitionToWide<D>;
+  using VW = Vec<DW>;
+
+  const D d;
+  const DU du;
+  const DH dh;
+  const DW dw;
+  // Load frequently used constants as vector registers
+  const V abs_mask = BitCast(d, Set(du, 0x7FFFFFFF));
+  const V x_abs = And(abs_mask, x);
+  const V x_sign = AndNot(x_abs, x);
+
+  // Transform cosine to sine using identity: cos(x) = sin(x + π/2)
+  const V half_pi = Set(d, data::kHalfPi<T>);
+  V x_trans = x_abs;
+  if constexpr (IS_COS) {
+    x_trans = Add(x_abs, half_pi);
+  }
+  // check zero input/subnormal for cosine (cos(~0) = 1)
+  const auto is_cos_near_zero = Eq(x_trans, half_pi);
+
+  // Compute N = round(input/π)
+  const V magic_round = Set(d, 0x1.8p23f);
+  V n_biased = MulAdd(x_trans, Set(d, data::kInvPi<T>), magic_round);
+  V n = Sub(n_biased, magic_round);
+
+  // Adjust quotient for cosine (accounts for π/2 phase shift)
+  if constexpr (IS_COS) {
+    // For cosine, we computed N = round((x + π/2)/π) but need N' for x:
+    //   N = round((x + π/2)/π) = round(x/π + 0.5)
+    // This is often 1 more than round(x/π), so we subtract 0.5:
+    //   N' = N - 0.5
+    n = Sub(n, Set(d, 0.5f));
+  }
+  auto WideCal = [](VW nh, VW xh_abs) -> VW {
+    const DFromV<VW> dw;
+    constexpr auto kPiPrec35 = data::kPiPrec35<kNativeFMA>;
+    VW r = NegMulAdd(nh, Set(dw, kPiPrec35[0]), xh_abs);
+    r = NegMulAdd(nh, Set(dw, kPiPrec35[1]), r);
+    VW r2 = Mul(r, r);
+
+    // Polynomial coefficients for sin(r) approximation on [-π/2, π/2]
+    const VW c9 = Set(dw, 0x1.5dbdf0e4c7deep-19);
+    const VW c7 = Set(dw, -0x1.9f6ffeea73463p-13);
+    const VW c5 = Set(dw, 0x1.110ed3804ca96p-7);
+    const VW c3 = Set(dw, -0x1.55554bc836587p-3);
+    VW poly = MulAdd(c9, r2, c7);
+    poly = MulAdd(r2, poly, c5);
+    poly = MulAdd(r2, poly, c3);
+    poly = Mul(poly, r2);
+    poly = MulAdd(r, poly, r);
+    return poly;
+  };
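+  // Why two NegMulAdd steps suffice (illustrative check): SinCos only routes
+  // |x| <= 10000 here, so |n| <= round(10000/π) < 2^12, and n times the
+  // leading kPiPrec35 piece (35 bits, or the 24-bit splits in the no-FMA
+  // table) stays within 12+35 <= 53 bits, i.e. the product is exact in
+  // double even without FMA.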
+
+  VW poly_lo = WideCal(PromoteLowerTo(dw, n), PromoteLowerTo(dw, x_abs));
+  VW poly_up = WideCal(PromoteUpperTo(dw, n), PromoteUpperTo(dw, x_abs));
+
+  V poly = Combine(d, DemoteTo(dh, poly_up), DemoteTo(dh, poly_lo));
+  // Extract octant sign information from quotient and flip the sign bit
+  poly = Xor(poly, BitCast(d, ShiftLeft<31>(BitCast(du, n_biased))));
+  if constexpr (IS_COS) {
+    poly = IfThenElse(is_cos_near_zero, Set(d, 1.0f), poly);
+  } else {
+    // Restore original sign for sine (odd function)
+    poly = Xor(poly, x_sign);
+  }
+  return poly;
+}
+/**
+ * This function computes sin(x) or cos(x) for |x| < 2^24 using the Cody-Waite
+ * reduction algorithm combined with table lookup and polynomial approximation,
+ * achieving < 1 ULP error for |x| < 2^24.
+ *
+ * Algorithm Overview:
+ * 1. Range Reduction: Reduces input x to r where |r| < π/16
+ *    - Computes n = round(x * 16/π) and r = x - n*π/16
+ *    - Uses multi-precision arithmetic (3 parts of π/16) for accuracy
+ *
+ * 2. Table Lookup: Retrieves precomputed sin(n*π/16) and cos(n*π/16)
+ *    - Includes high and low precision parts for cos values
+ *
+ * 3. Polynomial Approximation: Computes sin(r) and cos(r)
+ *    - sin(r) ≈ r * (1 + r²*P_sin(r²)) where P_sin is a minimax polynomial
+ *    - cos(r) ≈ 1 + r²*P_cos(r²) where P_cos is a minimax polynomial
+ *
+ * 4. Reconstruction: Applies angle addition formulas
+ *    - sin(x) = sin(n*π/16 + r) = sin(n*π/16)*cos(r) + cos(n*π/16)*sin(r)
+ *    - cos(x) = cos(n*π/16 + r) = cos(n*π/16)*cos(r) - sin(n*π/16)*sin(r)
+ *
+ */
+template <bool IS_COS, typename V, HWY_IF_F64(TFromV<V>)>
+HWY_INLINE V High(V x) {
+  using namespace hn;
+  namespace data = ::npsr::trig::data;
+
+  using T = TFromV<V>;
+  using D = DFromV<V>;
+  using DU = RebindToUnsigned<D>;
+  using VU = Vec<DU>;
+
+  const D d;
+  const DU du;
+
+  // Step 1: Range reduction - find n such that x = n*(π/16) + r, where
+  // |r| < π/16
+  V magic = Set(d, 0x1.8p52);
+  V n_biased = MulAdd(x, Set(d, data::k16DivPi<T>), magic);
+  V n = Sub(n_biased, magic);
+
+  // Extract integer index for table lookup (n mod 16)
+  VU n_int = BitCast(du, n_biased);
+  VU table_idx = And(n_int, Set(du, 0xF));  // Mask to get n mod 16
+
+  // Step 2: Load precomputed sine/cosine values for n mod 16
+  V sin_hi = LutX2(data::kHiSinKPi16Table, table_idx);
+  V cos_hi = LutX2(data::kHiCosKPi16Table, table_idx);
+  // Note: cos_lo and sin_lo are packed together (32 bits each) to save memory.
+  // cos_lo can be used as-is since it's in the upper bits, sin_lo needs
+  // extraction. The precision loss is negligible for the final result.
+  // see high.h.sol for the table generation code.
+  V cos_lo = LutX2(data::kPackedLowSinCosKPi16Table, table_idx);
+  // Extract sin_low from packed format (upper 32 bits)
+  V sin_lo = BitCast(d, ShiftLeft<32>(BitCast(du, cos_lo)));
+
+  // Step 3: Multi-precision computation of remainder r
+  // r = x - n*(π/16)_high
+  constexpr auto kPiDiv16Prec29 = data::kPiDiv16Prec29<kNativeFMA>;
+  V r_hi = NegMulAdd(n, Set(d, kPiDiv16Prec29[0]), x);
+  if constexpr (!kNativeFMA) {
+    // Without native FMA, fold in the second 27-bit split of π/16 first
+    r_hi = NegMulAdd(n, Set(d, kPiDiv16Prec29[3]), r_hi);
+  }
+  const V pi16_med = Set(d, kPiDiv16Prec29[1]);
+  const V pi16_lo = Set(d, kPiDiv16Prec29[2]);
+  V r_med = NegMulAdd(n, pi16_med, r_hi);
+  V r = NegMulAdd(n, pi16_lo, r_med);
+
+  // Compute low precision part of r for extra accuracy
+  V term = NegMulAdd(pi16_med, n, Sub(r_hi, r_med));
+  V r_lo = MulAdd(pi16_lo, n, Sub(r, r_med));
+  r_lo = Sub(term, r_lo);
+
+  // Step 4: Polynomial approximation
+  V r2 = Mul(r, r);
+
+  // Minimax polynomial for (sin(r)/r - 1)
+  //   sin(r)/r = 1 - r²/3! + r⁴/5! - r⁶/7! + ...
+  // This polynomial computes the terms after 1
+  V sin_poly = Set(d, 0x1.71c97d22a73ddp-19);
+  sin_poly = MulAdd(sin_poly, r2, Set(d, -0x1.a01a00ed01edep-13));
+  sin_poly = MulAdd(sin_poly, r2, Set(d, 0x1.111111110e99dp-7));
+  sin_poly = MulAdd(sin_poly, r2, Set(d, -0x1.5555555555555p-3));
+
+  // Minimax polynomial for (cos(r) - 1)/r²
+  //   cos(r) = 1 - r²/2! + r⁴/4! - r⁶/6! + ...
+  // This polynomial computes (cos(r) - 1)/r²
+  V cos_poly = Set(d, 0x1.9ffd7d9d749bcp-16);
+  cos_poly = MulAdd(cos_poly, r2, Set(d, -0x1.6c16c075d73f8p-10));
+  cos_poly = MulAdd(cos_poly, r2, Set(d, 0x1.555555554e8d6p-5));
+  cos_poly = MulAdd(cos_poly, r2, Set(d, -0x1.ffffffffffffcp-2));
+
+  // Step 5: Reconstruction using angle addition formulas
+  //
+  // Mathematical equivalence between traditional and SVML approaches:
+  //
+  // Traditional angle addition:
+  //   sin(a+r) = sin(a)*cos(r) + cos(a)*sin(r)
+  //   cos(a+r) = cos(a)*cos(r) - sin(a)*sin(r)
+  //
+  // Where for small r (|r| < π/16):
+  //   cos(r) ≈ 1 + r²*cos_poly
+  //   sin(r) ≈ r*(1 + sin_poly) ≈ r + r*sin_poly
+  //
+  // SVML's efficient linear approximation:
+  //   sin(a+r) ≈ sin(a) + cos(a)*r + polynomial_corrections
+  //   cos(a+r) ≈ cos(a) - sin(a)*r + polynomial_corrections
+  //
+  // This is mathematically equivalent but computationally more efficient:
+  //   - Uses first-order linear terms directly: Sh + Ch*R, Ch - R*Sh
+  //   - Applies higher-order polynomial corrections separately
+  //   - Fewer multiplications and better numerical stability
+  //
+  // Implementation follows SVML structure:
+  //   sin(n*π/16 + r) = sin_table + cos_table*remainder (+ corrections)
+  //   cos(n*π/16 + r) = cos_table - sin_table*remainder (+ corrections)
+  V result;
+  if constexpr (IS_COS) {
+    // Cosine reconstruction: cos_table - sin_table*remainder
+    // Equivalent to: cos(a)*cos(r) - sin(a)*sin(r) but more efficient
+    V res_hi = NegMulAdd(r, sin_hi, cos_hi);  // cos_hi - r*sin_hi
+
+    // This captures the precision lost in the main computation
+    V r_sin_hi = Sub(cos_hi, res_hi);  // Extract high part of multiplication
+
+    // Handles rounding errors and adds sin_low contribution
+    V r_sin_low = MulSub(r, sin_hi, r_sin_hi);   // Compute multiplication error
+    V sin_low_corr = MulAdd(r, sin_lo, r_sin_low);  // Add sin_low term
+
+    // This is used to apply the low-precision remainder correction
+    V sin_cos_r = MulAdd(r, cos_hi, sin_hi);
+
+    // Main low precision correction: cos_low - r_low*(sin_table + cos_table*r)
+    // Applies the effect of the low-precision remainder on the final result
+    V low_corr = NegMulAdd(r_lo, sin_cos_r, cos_lo);
+
+    // Polynomial corrections using the remainder
+    V r_sin = Mul(r, sin_hi);  // For polynomial application
+
+    // Apply polynomial corrections: cos_table*cos_poly - r*sin_table*sin_poly
+    // This handles the higher-order terms from cos(r) and sin(r) expansions
+    V poly_corr = Mul(cos_hi, cos_poly);  // cos(a) * (cos(r)-1)/r²
+    // - sin(a)*r * (sin(r)/r-1)
+    poly_corr = NegMulAdd(r_sin, sin_poly, poly_corr);
+
+    // Combine all low precision corrections
+    V total_low = Sub(low_corr, sin_low_corr);
+
+    // Final assembly: main_term + r²*polynomial_corrections + low_corrections
+    result = MulAdd(r2, poly_corr, total_low);
+    result = Add(res_hi, result);
+
+  } else {
+    // Sine reconstruction: sin_table + cos_table*remainder
+    // Equivalent to: sin(a)*cos(r) + cos(a)*sin(r) but more efficient
+    V res_hi = MulAdd(r, cos_hi, sin_hi);  // sin_hi + r*cos_hi
+
+    // This captures the precision lost in the main computation
+    V r_cos_hi = Sub(res_hi, sin_hi);  // Extract high part of multiplication
+
+    // Handles rounding errors and adds cos_low contribution
+    V r_cos_low = MulSub(r, cos_hi, r_cos_hi);   // Compute multiplication error
+    V cos_low_corr = MulAdd(r, cos_lo, r_cos_low);  // Add cos_low term
+
+    // Intermediate term for r_low correction: cos_table - sin_table*r
+    // This is used to apply the low-precision remainder correction
+    V cos_r_sin = NegMulAdd(r, sin_hi, cos_hi);
+
+    // Main low precision correction: sin_low - r_low*(cos_table - sin_table*r)
+    // Applies the effect of the low-precision remainder on the final result
+    V low_corr = MulAdd(r_lo, cos_r_sin, sin_lo);
+    // Polynomial corrections using the remainder
+    V r_cos = Mul(r, cos_hi);  // For polynomial application
+
+    // Apply polynomial corrections: sin_table*cos_poly + r*cos_table*sin_poly
+    // This handles the higher-order terms from cos(r) and sin(r) expansions
+    V poly_corr = Mul(sin_hi, cos_poly);  // sin(a) * (cos(r)-1)/r²
+    poly_corr =
+        MulAdd(r_cos, sin_poly, poly_corr);  // + cos(a)*r * (sin(r)/r-1)
+
+    // Combine all low precision corrections
+    V total_low = Add(low_corr, cos_low_corr);
+    // Final assembly: main_term + r²*polynomial_corrections + low_corrections
+    result = MulAdd(r2, poly_corr, total_low);
+    result = Add(res_hi, result);
+  }
+
+  // Apply final sign correction same for both sine and cosine
+  // Both functions change sign every π radians, corresponding to bit 4 of n_int
+  // This unified approach works because:
+  //   - sin(x + π) = -sin(x)
+  //   - cos(x + π) = -cos(x)
+  VU x_sign_int = ShiftLeft<63>(BitCast(du, x));
+  // XOR with quadrant info in n_biased
+  VU combined = Xor(BitCast(du, n_biased), ShiftLeft<4>(x_sign_int));
+  // Extract final sign
+  VU sign = ShiftRight<4>(combined);
+  sign = ShiftLeft<63>(sign);
+  result = Xor(result, BitCast(d, sign));  // Apply sign flip
+  return result;
+}
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace npsr::HWY_NAMESPACE::trig
+
+HWY_AFTER_NAMESPACE();
+
+#endif  // NPSR_TRIG_HIGH_INL_H_
diff --git a/npsr/trig/inl.h b/npsr/trig/inl.h
new file mode 100644
index 0000000..229c525
--- /dev/null
+++ b/npsr/trig/inl.h
@@ -0,0 +1,59 @@
+#include "npsr/common.h"
+#include "npsr/trig/extended-inl.h"
+#include "npsr/trig/high-inl.h"
+#include "npsr/trig/low-inl.h"
+
+#if defined(NPSR_TRIG_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NPSR_TRIG_INL_H_
+#undef NPSR_TRIG_INL_H_
+#else
+#define NPSR_TRIG_INL_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+
+namespace npsr::HWY_NAMESPACE::trig {
+template <bool IS_COS, typename Prec, typename V>
+HWY_API V SinCos(Prec &prec, V x) {
+  using namespace hwy::HWY_NAMESPACE;
+  constexpr bool kIsSingle = std::is_same_v<TFromV<V>, float>;
+  const DFromV<V> d;
+  V ret;
+  if constexpr (Prec::kLowAccuracy) {
+    ret = Low<IS_COS>(x);
+  } else {
+    ret = High<IS_COS>(x);
+  }
+  if constexpr (Prec::kLargeArgument) {
+    // Identify inputs requiring extended precision (very large arguments)
+    auto has_large_arg = Gt(Abs(x), Set(d, kIsSingle ? 10000.0f : 16777216.0));
+    if (HWY_UNLIKELY(!AllFalse(d, has_large_arg))) {
+      // Use extended precision algorithm for large arguments
+      ret = IfThenElse(has_large_arg, Extended<IS_COS>(x), ret);
+    }
+  }
+  if constexpr (Prec::kSpecialCases || Prec::kExceptions) {
+    auto is_finite = IsFinite(x);
+    ret = IfThenElse(is_finite, ret, NaN(d));
+    if constexpr (Prec::kExceptions) {
+      prec.Raise(!AllFalse(d, IsInf(x)) ? FPExceptions::kInvalid : 0);
+    }
+  }
+  return ret;
+}
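+// Tier summary (illustrative): Low/High cover |x| <= 10000 (F32) or
+// |x| <= 2^24 (F64); lanes above that are recomputed by the Payne-Hanek
+// style Extended path, so large arguments keep an accurately reduced
+// remainder instead of losing it to rounding in n*π.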
+}  // namespace npsr::HWY_NAMESPACE::trig
+
+namespace npsr::HWY_NAMESPACE {
+template <typename Prec, typename V>
+HWY_API V Sin(Prec &prec, V x) {
+  return trig::SinCos<false>(prec, x);
+}
+template <typename Prec, typename V>
+HWY_API V Cos(Prec &prec, V x) {
+  return trig::SinCos<true>(prec, x);
+}
+}  // namespace npsr::HWY_NAMESPACE
+
+HWY_AFTER_NAMESPACE();
+
+#endif  // NPSR_TRIG_INL_H_
diff --git a/npsr/trig/low-inl.h b/npsr/trig/low-inl.h
new file mode 100644
index 0000000..e8cc815
--- /dev/null
+++ b/npsr/trig/low-inl.h
@@ -0,0 +1,151 @@
+#include "npsr/common.h"
+#include "npsr/trig/data/data.h"
+#include "npsr/utils-inl.h"
+
+#if defined(NPSR_TRIG_LOW_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NPSR_TRIG_LOW_INL_H_
+#undef NPSR_TRIG_LOW_INL_H_
+#else
+#define NPSR_TRIG_LOW_INL_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+
+namespace npsr::HWY_NAMESPACE::trig {
+
+template <bool IS_COS, typename V, HWY_IF_F32(TFromV<V>)>
+HWY_API V PolyLow(V r, V r2) {
+  using namespace hn;
+
+  const DFromV<V> d;
+  const V c9 = Set(d, IS_COS ? 0x1.5d866ap-19f : 0x1.5dbdfp-19f);
+  const V c7 = Set(d, IS_COS ? -0x1.9f6d9ep-13f : -0x1.9f6ffep-13f);
+  const V c5 = Set(d, IS_COS ? 0x1.110ec8p-7f : 0x1.110eccp-7f);
+  const V c3 = Set(d, -0x1.55554cp-3f);
+  V poly = MulAdd(c9, r2, c7);
+  poly = MulAdd(r2, poly, c5);
+  poly = MulAdd(r2, poly, c3);
+  if constexpr (IS_COS) {
+    // Although this path handles cosine, we have already transformed the
+    // input using the identity: cos(x) = sin(x + π/2). This means we're no
+    // longer directly evaluating a cosine Taylor series; instead, we evaluate
+    // the sine approximation polynomial at (x + π/2).
+    //
+    // The sine approximation has the general form:
+    //   sin(r) ≈ r + r³ · P(r²)
+    //
+    // So, we compute:
+    //   r³ = r · r²
+    //   sin(r) ≈ r + r³ · poly
+    //
+    // This formulation preserves accuracy by computing the highest order
+    // terms last, which benefits from FMA to reduce rounding error.
+    V r3 = Mul(r2, r);
+    poly = MulAdd(r3, poly, r);
+  } else {
+    poly = Mul(poly, r2);
+    poly = MulAdd(r, poly, r);
+  }
+  return poly;
+}
+
+template <bool IS_COS, typename V, HWY_IF_F64(TFromV<V>)>
+HWY_API V PolyLow(V r, V r2) {
+  using namespace hn;
+
+  const DFromV<V> d;
+  const V c15 = Set(d, -0x1.9f1517e9f65fp-41);
+  const V c13 = Set(d, 0x1.60e6bee01d83ep-33);
+  const V c11 = Set(d, -0x1.ae6355aaa4a53p-26);
+  const V c9 = Set(d, 0x1.71de3806add1ap-19);
+  const V c7 = Set(d, -0x1.a01a019a659ddp-13);
+  const V c5 = Set(d, 0x1.111111110a573p-7);
+  const V c3 = Set(d, -0x1.55555555554a8p-3);
+  V poly = MulAdd(c15, r2, c13);
+  poly = MulAdd(r2, poly, c11);
+  poly = MulAdd(r2, poly, c9);
+  poly = MulAdd(r2, poly, c7);
+  poly = MulAdd(r2, poly, c5);
+  poly = MulAdd(r2, poly, c3);
+  return poly;
+}
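+// Coefficient sanity check (illustrative): the leading terms match the
+// Taylor series of sin(r)/r - 1, since c3 = -0x1.5555...p-3 ≈ -1/6 = -1/3!
+// and c5 = 0x1.1111...p-7 ≈ 1/120 = 1/5!; the remaining digits are the
+// minimax adjustment for the reduced range.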
+
+template <bool IS_COS, typename V>
+HWY_API V Low(V x) {
+  using namespace hn;
+  using hwy::SignMask;
+  namespace data = ::npsr::trig::data;
+
+  const DFromV<V> d;
+  const RebindToUnsigned<decltype(d)> du;
+  using T = TFromV<V>;
+  // Load frequently used constants as vector registers
+  const V abs_mask = BitCast(d, Set(du, SignMask<T>() - 1));
+  const V x_abs = And(abs_mask, x);
+  const V x_sign = AndNot(x_abs, x);
+
+  constexpr bool kIsSingle = std::is_same_v<T, float>;
+  // Transform cosine to sine using identity: cos(x) = sin(x + π/2)
+  const V half_pi = Set(d, data::kHalfPi<T>);
+  V x_trans = x_abs;
+  if constexpr (IS_COS) {
+    x_trans = Add(x_abs, half_pi);
+  }
+  // check zero input/subnormal for cosine (cos(~0) = 1)
+  const auto is_cos_near_zero = Eq(x_trans, half_pi);
+
+  // Compute N = round(x/π) using the "magic number" technique,
+  // which stores the integer part in the mantissa
+  const V magic_round = Set(d, kIsSingle ? 0x1.8p23f : 0x1.8p52);
+  V n_biased = MulAdd(x_trans, Set(d, data::kInvPi<T>), magic_round);
+  V n = Sub(n_biased, magic_round);
+
+  // Adjust quotient for cosine (accounts for π/2 phase shift)
+  if constexpr (IS_COS) {
+    // For cosine, we computed N = round((x + π/2)/π) but need N' for x:
+    //   N = round((x + π/2)/π) = round(x/π + 0.5)
+    // This is often 1 more than round(x/π), so we subtract 0.5:
+    //   N' = N - 0.5
+    n = Sub(n, Set(d, static_cast<T>(0.5)));
+  }
+  // Use Cody-Waite method with triple-precision PI
+  constexpr auto kPi = data::kPi<T, kNativeFMA>;
+
+  V r = NegMulAdd(n, Set(d, kPi[0]), x_abs);
+  r = NegMulAdd(n, Set(d, kPi[1]), r);
+  V r_lo = NegMulAdd(n, Set(d, kPi[2]), r);
+  if constexpr (!kNativeFMA) {
+    if constexpr (!kIsSingle) {
+      r = r_lo;
+    }
+    r_lo = NegMulAdd(n, Set(d, kPi[3]), r_lo);
+  }
+
+  if constexpr (kIsSingle) {
+    r = r_lo;
+  }
+  V r2 = Mul(r, r);
+  V poly = PolyLow<IS_COS>(r, r2);
+
+  if constexpr (!kIsSingle) {
+    V r2_corr = Mul(r2, r_lo);
+    poly = MulAdd(r2_corr, poly, r_lo);
+  }
+
+  // Extract octant sign information from quotient and flip the sign bit
+  poly = Xor(poly,
+             BitCast(d, ShiftLeft<sizeof(T) * 8 - 1>(BitCast(du, n_biased))));
+  if constexpr (IS_COS) {
+    poly = IfThenElse(is_cos_near_zero, Set(d, static_cast<T>(1.0)), poly);
+  } else {
+    // Restore original sign for sine (odd function)
+    poly = Xor(poly, x_sign);
+  }
+  return poly;
+}
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace npsr::HWY_NAMESPACE::trig
+
+HWY_AFTER_NAMESPACE();
+
+#endif  // NPSR_TRIG_LOW_INL_H_
diff --git a/npsr/utils-inl.h b/npsr/utils-inl.h
new file mode 100644
index 0000000..59cbdae
--- /dev/null
+++ b/npsr/utils-inl.h
@@ -0,0 +1,96 @@
+#include "npsr/common.h"
+
+// clang-format off
+#if defined(NPSR_UTILS_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
+#ifdef NPSR_UTILS_INL_H_
+#undef NPSR_UTILS_INL_H_
+#else
+#define NPSR_UTILS_INL_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+
+namespace npsr::HWY_NAMESPACE {
+
+template <typename T, typename VU, typename D = Rebind<T, DFromV<VU>>,
+          typename V = VFromD<D>>
+HWY_API V LutX2(const T *lut, VU idx) {
+  using namespace hn;
+  D d;
+  return GatherIndex(d, lut, BitCast(RebindToSigned<D>(), idx));
+#if 0
+  D d;
+  if constexpr(MaxLanes(d) == sizeof(T)) {
+    const V lut0 = Load(d, lut);
+    const V lut1 = Load(d, lut + sizeof(T));
+    return TwoTablesLookupLanes(d, lut0, lut1, IndicesFromVec(d, idx));
+  }
+  else if constexpr (MaxLanes(d) == 4){
+    const V lut0 = Load(d, lut);
+    const V lut1 = Load(d, lut + 4);
+    const V lut2 = Load(d, lut + sizeof(T));
+    const V lut3 = Load(d, lut + 12);
+
+    const auto high_mask = Ne(ShiftRight<3>(idx), Zero(u64));
+    const auto load_mask = And(idx, Set(u64, 0b111));
+
+    const V lut_low = TwoTablesLookupLanes(d, lut0, lut1, IndicesFromVec(d, load_mask));
+    const V lut_high = TwoTablesLookupLanes(d, lut2, lut3, IndicesFromVec(d, load_mask));
+
+    return IfThenElse(RebindMask(d, high_mask), lut_high, lut_low);
+  }
+  else{
+    return GatherIndex(d, lut, BitCast(s64, idx));
+  }
+#endif
+}
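+// The disabled Lut class below is a kept-for-reference experiment: it stores
+// a perfect-square table in transposed order so that fixed-width targets
+// could use in-register lookups instead of the gather above.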
+#if 0
+template <typename... Args>
+class Lut {
+public:
+  using T = std::tuple_element_t<0, std::tuple<Args...>>;
+  constexpr static size_t kSize = sizeof...(Args);
+
+  const T *Data() const {
+    return array;
+  }
+
+#if 1 || HWY_MAX_BYTES == 16
+  // Calculate square root if it's a perfect square
+  constexpr static size_t kDim = []() {
+    for (size_t i = 1; i * i <= kSize; ++i) {
+      if (i * i == kSize) return i;
+    }
+    return size_t(0);
+  }();
+
+  static_assert(kDim > 0, "Must provide a perfect square number of array");
+  constexpr static size_t kRows = kDim;
+  constexpr static size_t kCols = kDim;
+
+  constexpr Lut(Args... args)
+      : Lut(std::make_tuple(args...), std::make_index_sequence<kSize>{}) {
+    static_assert(kDim > 0, "Must provide a perfect square number of array");
+  }
+
+private:
+  template <size_t... Is>
+  constexpr Lut(std::tuple<Args...> arg_tuple, std::index_sequence<Is...>)
+      : array{static_cast<T>(std::get<(Is % kRows) * kCols + (Is / kRows)>(arg_tuple))...} {
+  }
+
+#else
+public:
+  constexpr Lut(Args... args)
+      : array{static_cast<T>(args)...} {
+  }
+#endif
+
+private:
+  HWY_ALIGN T array[kSize];
+};
+
+#endif
+}  // namespace npsr::HWY_NAMESPACE
+HWY_AFTER_NAMESPACE();
+
+#endif  // NPSR_UTILS_INL_H_
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..f5fc80e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,41 @@
+[build-system]
+build-backend = "mesonpy"
+requires = [
+    "meson-python>=0.15.0",
+]
+
+[project]
+name = "numpy-simd-routines-test"
+version = "1.0.0.dev0"
+# TODO: add `license-files` once PEP 639 is accepted (see meson-python#88)
+license = {file = "LICENSE.txt"}
+
+description = "Test package for the NumPy SIMD routines"
+authors = [{name = "NumPy Developers"}]
+maintainers = [
+    {name = "NumPy Developers", email="numpy-discussion@python.org"},
+]
+requires-python = ">=3.11"
+readme = "README.md"
+classifiers = [
+    'Intended Audience :: Science/Research',
+]
+
+[project.urls]
+homepage = "https://numpy.org"
+documentation = "https://numpy.org/doc/"
+source = "https://github.com/numpy/numpy"
+download = "https://pypi.org/project/numpy/#files"
+tracker = "https://github.com/numpy/numpy/issues"
+"release notes" = "https://numpy.org/doc/stable/release"
+
+#[tool.meson-python.args]
+#install = ['--tags=runtime,python-runtime,tests,devel']
+
+[tool.spin]
+package = 'numpy_sr'
+
+[tool.spin.commands]
+"Build" = [
+    ".spin/cmds.py:generate",
+]
diff --git a/tools/sollya/core.sol b/tools/sollya/core.sol
new file mode 100644
index 0000000..66ac25b
--- /dev/null
+++ b/tools/sollya/core.sol
@@ -0,0 +1,311 @@
+prec = 512;
+display = hexadecimal;
+verbosity = 3;
+showmessagenumbers = on;
+
+THE_OUTPUT_LINES = [||];
+THE_DISPLAY_STACK = [||];
+THE_PREC_STACK = [||];
+
+Float32 = {
+  .kName = "float32",
+  .kSize = 32,
+  .kExpBits = 8,
+  .kMantBits = 23,
+  .kDigits = 24,
+  .kDigits10 = 6,
+  .kMaxDigits10 = 9,
+  .kMinExp = -126,
+  .kMinExp10 = -37,
+  .kBias = 127,
+  .kMaxExp10 = 38,
+  .kMinExpDenorm = -149,
+  .kMaxExpBiased = 254,
+  .kMin = 0x1p-126,
+  .kLowest = -0x1.fffffep127,
+  .kMax = 0x1.fffffep127,
+  .kEps = 0x1p-23,
+  .kDenormMin = 0x1p-149,
+  .kPyName = "float32_t",
+  .kCSFX = "f",
+  .kCName = "float",
+  .kCUint = "uint32_t",
+  .kCUintSFX = "u",
+  .kRound = single(x),
+  .kRoundStr = "single",
+  .kPrintDigits = "printsingle"
+};
+
+Float64 = {
+  .kName = "float64",
+  .kSize = 64,
+  .kExpBits = 11,
+  .kMantBits = 52,
+  .kDigits = 53,
+  .kDigits10 = 15,
+  .kMaxDigits10 = 17,
+  .kMinExp = -1022,
+  .kMinExp10 = -307,
+  .kBias = 1023,
+  .kMaxExp10 = 308,
+  .kMinExpDenorm = -1074,
+  .kMaxExpBiased = 2046,
+  .kMin = 0x1p-1022,
+  .kLowest = -0x1.fffffffffffffp1023,
+  .kMax = 0x1.fffffffffffffp1023,
+  .kEps = 0x1p-52,
+  .kDenormMin = 0x1p-1074,
+  .kPyName = "float64_t",
+  .kCSFX = "",
+  .kCName = "double",
+  .kCUint = "uint64_t",
+  .kCUintSFX = "ull",
+  .kRound = double(x),
+  .kRoundStr = "double",
+  .kPrintDigits = "printdouble"
+};
+
+procedure RightShift(pN, pK) {
+  return floor(pN / 2^pK);
+};
+
+procedure LeftShift(pN, pK) {
+  return pN * 2^pK;
+};
+
+procedure Join(pList, pSep) {
+  var r, i, v;
+  r = "";
+  for i in pList do {
+    v = i @ pSep;
+    r = r @ v;
+  };
+  return r;
+};
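+
+// e.g. RightShift(13, 2) == floor(13 / 4) == 3 and LeftShift(13, 2) == 52;
+// these operate on exact rationals, mirroring C-style >> and << for the
+// chunking and packing procedures in the data generators.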
+
+procedure PyEval(pCode = ...) {
+  var code;
+  write(Join(pCode, "\n")) > PYTEMP_FILE_PATH;
+  code = bashevaluate("python3 " @ PYTEMP_FILE_PATH);
+  return code;
+};
+
+procedure SolEval(pCode = ...) {
+  var code;
+  write(Join(pCode, "\n") @ "quit;") > PYTEMP_FILE_PATH;
+  code = bashevaluate("sollya " @ PYTEMP_FILE_PATH);
+  return code;
+};
+
+procedure ToDigits(pT, pA) {
+  var i, code, prfunc, $;
+  code = "";
+  prfunc = pT.kPrintDigits @ "(";
+  for i in pA do {
+    code = code @ (prfunc @ i @ ");");
+  };
+  $.hex = SolEval(code);
+  $.ints = PyEval(
+    "x = '''",
+    $.hex,
+    "'''",
+    "x = [str(int(l, base=0x10)) for l in x.splitlines() if l.strip()]",
+    "print('[|', ', '.join(x), '|];')"
+  );
+  return parse($.ints);
+};
+
+procedure FromDigits(pT, pA) {
+  var i, code, $;
+  SetDisplay(decimal);
+  $.hex = PyEval(
+    "x = (",
+    ToStringPyArray(pA, 8),
+    ")",
+    "rstr = '" @ pT.kRoundStr @ "'",
+    "pad = '0" @ pT.kSize / 4 @ "x'",
+    "x = [f'{rstr}(0x{format(l, pad)})' for l in x]",
+    "print('[|', ', '.join(x), '|];')"
+  );
+  RestoreDisplay();
+  return parse($.hex);
+};
+
+procedure ToStringArray(pData, pSFX, pColNum) {
+  var r, r_final, i, j, col_widths, $;
+  r = [||];
+  for i from 0 to length(pData) - 1 do {
+    $.v = pData[i] @ pSFX @ ", ";
+    if ($.v == "0f, ") then $.v = "0.0f, ";
+    r = r :. $.v;
+  };
+  // Determine the max width for each column
+  col_widths = [||];
+  for i from 0 to pColNum - 1 do {
+    col_widths = col_widths :. 0;
+  };
+  for i from 0 to length(r) - 1 do {
+    $.idx = mod(i, pColNum);
+    if (length(r[i]) > col_widths[$.idx]) then {
+      col_widths[$.idx] = length(r[i]);
+    };
+  };
+
+  // Create padding string
+  $.pad = "";
+  for i from 1 to 2 do $.pad = $.pad @ " ";
+
+  // Build lines array
+  r_final = "";
+  i = 0;
+  while (i < length(r)) do {
+    var chunks, chunk, line;
+    // Create chunk of 'col' elements
+    chunks = [||];
+    for j from 0 to pColNum - 1 do {
+      if (i + j < length(r)) then {
+        chunks = chunks :. r[i + j];
+      };
+    };
+    line = $.pad;
+    for j from 0 to length(chunks) - 1 do {
+      $.idx = mod(i + j, pColNum);
+      chunk = chunks[j];
+      // Left-justify to width
+      while (length(chunk) < col_widths[$.idx]) do {
+        chunk = chunk @ " ";
+      };
+      line = line @ chunk;
+    };
+    r_final = r_final @ line @ "\n";
+    i = i + pColNum;
+  };
+  return r_final;
+};
+
+procedure ToStringCArray(pArr, pSFX, pNumCol) {
+  return "{\n" @ ToStringArray(pArr, pSFX, pNumCol) @ "};";
+};
+
+procedure ToStringPyArray(pArr, pNumCol) {
+  return "[\n" @ ToStringArray(pArr, "", pNumCol) @ "]";
+};
+
+procedure ConstantsFromArray(pArr) {
+  var r, i, j, $;
+  r = [||];
+  $.exact = head(pArr);
+  $.remainder = 0;
+  for i in tail(pArr) do {
+    $.r_mod = head(i);
+    for j in tail(i) do {
+      $.val = round($.exact - $.remainder, j, $.r_mod);
+      $.remainder = $.remainder + $.val;
+      r = r :. $.val;
+    };
+  };
+  return r;
+};
+procedure Constants(pArgs = ...) {
+  return ConstantsFromArray(pArgs);
+};
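+
+// Example (illustrative): ConstantsFromArray([|pi, [|RN, 24, 24|]|]) returns
+// [|hi, lo|] with hi = RN24(pi) and lo = RN24(pi - hi), so hi + lo carries
+// about 48 bits of π; the longer [|mode, digits, ...|] lists used in
+// constants.h.sol build the multi-word Cody-Waite splits the same way.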
+
+procedure Append(pLines = ...) {
+  suppressmessage(56);
+  THE_OUTPUT_LINES = THE_OUTPUT_LINES @ pLines;
+  unsuppressmessage(56);
+};
+
+procedure SetDisplay(pMod) {
+  suppressmessage(56);
+  THE_DISPLAY_STACK = display .: THE_DISPLAY_STACK;
+  unsuppressmessage(56);
+  display = pMod;
+};
+
+procedure RestoreDisplay() {
+  Assert(
+    length(THE_DISPLAY_STACK) > 0,
+    "Display stack is empty, cannot restore display."
+  );
+  display = head(THE_DISPLAY_STACK);
+  suppressmessage(56);
+  if (length(THE_DISPLAY_STACK) == 1) then {
+    THE_DISPLAY_STACK = [||];
+  } else {
+    THE_DISPLAY_STACK = tail(THE_DISPLAY_STACK);
+  };
+  unsuppressmessage(56);
+};
+
+procedure SetPrec(pPrec) {
+  suppressmessage(56);
+  THE_PREC_STACK = prec .: THE_PREC_STACK;
+  unsuppressmessage(56);
+  prec = pPrec;
+};
+
+procedure RestorePrec() {
+  Assert(
+    length(THE_PREC_STACK) > 0,
+    "Prec stack is empty, cannot restore prec."
+  );
+  prec = head(THE_PREC_STACK);
+  suppressmessage(56);
+  if (length(THE_PREC_STACK) == 1) then {
+    THE_PREC_STACK = [||];
+  } else {
+    THE_PREC_STACK = tail(THE_PREC_STACK);
+  };
+  unsuppressmessage(56);
+};
+
+procedure Assert(pCondition, pMessage) {
+  if (!pCondition) then {
+    "Assertion failed: " @ pMessage;
+    PyEval(
+      "import os, signal; from pathlib import Path;",
+      "Path('" @ OUTPUT_FILE_PATH @ "').unlink(missing_ok=True)",
+      "os.kill(os.getppid(), signal.SIGKILL)"
+    );
+  };
+};
+
+procedure Dump() {
+  var i;
+  for i in THE_OUTPUT_LINES do {
+    i;
+  };
+  Assert(false, "Dump");
+};
+
+procedure Write() {
+  write(Join(THE_OUTPUT_LINES, "\n")) > OUTPUT_FILE_PATH;
+  suppressmessage(56);
+  THE_OUTPUT_LINES = [||];
+  unsuppressmessage(56);
+};
+
+procedure WriteCPPHeader(pNamespace = ...) {
+  var i, vNamespace, $;
+  $.pre = [|
+    "// Auto-generated by " @ SOURCE_FILE_PATH,
+    "// Use `spin generate -f` to force regeneration",
+    "#ifndef " @ SOURCE_GUARD_NAME,
+    "#define " @ SOURCE_GUARD_NAME,
+    ""
+  |];
+  $.post = [||];
+  for i in pNamespace do {
+    vNamespace = "namespace " @ i;
+    $.pre = $.pre :. (vNamespace @ " {");
+    $.post = $.post :. ("} // " @ vNamespace);
+  };
+  $.post = $.post @ [|
+    "",
+    "#endif // " @ SOURCE_GUARD_NAME
+  |];
+  suppressmessage(56);
+  THE_OUTPUT_LINES = $.pre @ THE_OUTPUT_LINES @ $.post;
+  unsuppressmessage(56);
+  Write();
+};
+
diff --git a/tools/sollya/generate.py b/tools/sollya/generate.py
new file mode 100644
index 0000000..c2f8738
--- /dev/null
+++ b/tools/sollya/generate.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+import subprocess
+import tempfile
+import os
+import pathlib
+import glob
+import sys
+from itertools import chain
+
+curdir = pathlib.Path(__file__).parent
+rootdir = curdir.parent.parent
+sys.path.insert(0, str(curdir))
+
+
+def sollya(sollya_file, out, encoding="utf-8"):
+    print(f"Executing {sollya_file}...")
+    rout = str(pathlib.Path(out).resolve().relative_to(rootdir))
+    rsoll = str(pathlib.Path(sollya_file).resolve().relative_to(rootdir))
+    guard_name = rout.upper().replace("/", "_").replace(".", "_").replace("-", "_")
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".py", delete=False, encoding=encoding
+    ) as f:
+        pycode_temp = f.name
+
+    pre = "\n".join(
+        [
+            f'SOURCE_GUARD_NAME = "{guard_name}";',
+            f'SOURCE_FILE_PATH = "{rsoll}";',
+            f'OUTPUT_FILE_PATH = "{out}";',
+            f'PYTEMP_FILE_PATH = "{pycode_temp}";',
+            f'execute("{curdir/"core.sol"}");',
+        ]
+    )
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".sol", delete=False, encoding=encoding
+    ) as f:
+        f.write(pre + "\n")
+        with open(sollya_file, "r", encoding=encoding) as rf:
+            f.write(rf.read().strip())
+        f.write("quit;\n")
+        sollya_temp = f.name
+
+    try:
+        # Execute Sollya with temp file
+        result = subprocess.run(
+            ["sollya", sollya_temp], cwd=pathlib.Path(sollya_file).parent
+        )
+        if result.returncode != 0:
+            raise RuntimeError(f"Sollya execution failed with code {result.returncode}")
+    finally:
+        # Clean up temp file
+        os.unlink(sollya_temp)
+        os.unlink(pycode_temp)
+
+
+def main(force):
+    print("Generating sollya files...")
+    path = rootdir / "npsr"
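+    # Template discovery (illustrative): npsr/**/data/*.h.sol maps to the
+    # same-named .h output, likewise *.py.sol and *.csv.sol; existing outputs
+    # are skipped unless --force (or `spin generate -f`) is given.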
+    exts = ["*.h", "*.py", "*.csv"]
+    from_exts = [f"{ext}.sol" for ext in exts]
+    patterns = [f"{path}/**/data/{ext}" for ext in from_exts]
+    files = list(chain.from_iterable(glob.glob(p, recursive=True) for p in patterns))
+
+    for f in files:
+        out = f
+        for frm, to in zip(from_exts, exts):
+            out = out.replace(frm[1:], to[1:])
+        if not force and pathlib.Path(out).exists():
+            print(f"Skipping {out}, file already exists")
+            continue
+        sollya(f, out)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Generate C++ headers/python templates from Sollya scripts."
+    )
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        help="Force regenerate all files, even if they already exist.",
+    )
+    args = parser.parse_args()
+    main(force=args.force)
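+
+# Typical invocations (illustrative), from the repository root:
+#   python tools/sollya/generate.py      # generate missing outputs only
+#   spin generate -f                     # force-regenerate everything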