|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of libcu++, the C++ Standard Library for your entire system, |
| 4 | +// under the Apache License v2.0 with LLVM Exceptions. |
| 5 | +// See https://llvm.org/LICENSE.txt for license information. |
| 6 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. |
| 8 | +// |
| 9 | +//===----------------------------------------------------------------------===// |
| 10 | + |
| 11 | +#ifndef _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H |
| 12 | +#define _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H |
| 13 | + |
| 14 | +#include <cuda/std/detail/__config> |
| 15 | + |
| 16 | +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) |
| 17 | +# pragma GCC system_header |
| 18 | +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) |
| 19 | +# pragma clang system_header |
| 20 | +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) |
| 21 | +# pragma system_header |
| 22 | +#endif // no system header |
| 23 | + |
| 24 | +#include <cuda/std/__functional/invoke.h> |
| 25 | +#include <cuda/std/__iterator/concepts.h> |
| 26 | +#include <cuda/std/__iterator/iterator_traits.h> |
| 27 | +#include <cuda/std/__ranges/movable_box.h> |
| 28 | +#include <cuda/std/__type_traits/conditional.h> |
| 29 | +#include <cuda/std/__type_traits/is_nothrow_copy_constructible.h> |
| 30 | +#include <cuda/std/__type_traits/is_nothrow_default_constructible.h> |
| 31 | +#include <cuda/std/__type_traits/is_nothrow_move_constructible.h> |
| 32 | +#include <cuda/std/__type_traits/is_same.h> |
| 33 | +#include <cuda/std/__type_traits/remove_cvref.h> |
| 34 | +#include <cuda/std/__utility/forward.h> |
| 35 | +#include <cuda/std/__utility/move.h> |
| 36 | +#include <cuda/std/cstdint> |
| 37 | + |
| 38 | +#include <cuda/std/__cccl/prologue.h> |
| 39 | + |
| 40 | +_LIBCUDACXX_BEGIN_NAMESPACE_CUDA |
| 41 | + |
| 42 | +template <class _Fn, class _Index = _CUDA_VSTD::ptrdiff_t> |
| 43 | +class tabulate_output_iterator; |
| 44 | + |
| 45 | +template <class _Fn, class _Index> |
| 46 | +class __tabulate_proxy |
| 47 | +{ |
| 48 | +private: |
| 49 | + template <class, class> |
| 50 | + friend class tabulate_output_iterator; |
| 51 | + |
| 52 | + _Fn& __func_; |
| 53 | + _Index __index_; |
| 54 | + |
| 55 | +public: |
| 56 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr explicit __tabulate_proxy(_Fn& __func, _Index __index) noexcept |
| 57 | + : __func_(__func) |
| 58 | + , __index_(__index) |
| 59 | + {} |
| 60 | + |
| 61 | + _CCCL_TEMPLATE(class _Arg) |
| 62 | + _CCCL_REQUIRES(_CUDA_VSTD::is_invocable_v<const _Fn&, _Index, _Arg> _CCCL_AND( |
| 63 | + !_CUDA_VSTD::is_same_v<_CUDA_VSTD::remove_cvref_t<_Arg>, __tabulate_proxy>)) |
| 64 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr const __tabulate_proxy& operator=(_Arg&& __arg) const |
| 65 | + noexcept(_CUDA_VSTD::is_nothrow_invocable_v<const _Fn&, _Index, _Arg>) |
| 66 | + { |
| 67 | + _CUDA_VSTD::invoke(__func_, __index_, _CUDA_VSTD::forward<_Arg>(__arg)); |
| 68 | + return *this; |
| 69 | + } |
| 70 | + |
| 71 | + _CCCL_TEMPLATE(class _Arg) |
| 72 | + _CCCL_REQUIRES(_CUDA_VSTD::is_invocable_v<_Fn&, _Index, _Arg> _CCCL_AND( |
| 73 | + !_CUDA_VSTD::is_same_v<_CUDA_VSTD::remove_cvref_t<_Arg>, __tabulate_proxy>)) |
| 74 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr const __tabulate_proxy& |
| 75 | + operator=(_Arg&& __arg) noexcept(_CUDA_VSTD::is_nothrow_invocable_v<_Fn&, _Index, _Arg>) |
| 76 | + { |
| 77 | + _CUDA_VSTD::invoke(__func_, __index_, _CUDA_VSTD::forward<_Arg>(__arg)); |
| 78 | + return *this; |
| 79 | + } |
| 80 | +}; |
| 81 | + |
| 82 | +//! @p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a |
| 83 | +//! dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced |
| 84 | +//! iterator and the assigned value. |
| 85 | +//! |
| 86 | +//! The following code snippet demonstrated how to create a \p tabulate_output_iterator which prints the index and the |
| 87 | +//! assigned value. |
| 88 | +//! |
| 89 | +//! @code |
| 90 | +//! #include <cuda/iterator> |
| 91 | +//! |
| 92 | +//! struct print_op |
| 93 | +//! { |
| 94 | +//! __host__ __device__ void operator()(int index, float value) const |
| 95 | +//! { |
| 96 | +//! printf("%d: %f\n", index, value); |
| 97 | +//! } |
| 98 | +//! }; |
| 99 | +//! |
| 100 | +//! int main() |
| 101 | +//! { |
| 102 | +//! auto tabulate_it = cuda::make_tabulate_output_iterator(print_op{}); |
| 103 | +//! |
| 104 | +//! tabulate_it[0] = 1.0f; // prints: 0: 1.0 |
| 105 | +//! tabulate_it[1] = 3.0f; // prints: 1: 3.0 |
| 106 | +//! tabulate_it[9] = 5.0f; // prints: 9: 5.0 |
| 107 | +//! } |
| 108 | +//! @endcode |
| 109 | +template <class _Fn, class _Index> |
| 110 | +class tabulate_output_iterator |
| 111 | +{ |
| 112 | +private: |
| 113 | + _CCCL_NO_UNIQUE_ADDRESS _CUDA_VRANGES::__movable_box<_Fn> __func_; |
| 114 | + _Index __index_ = 0; |
| 115 | + |
| 116 | +public: |
| 117 | + using iterator_concept = _CUDA_VSTD::random_access_iterator_tag; |
| 118 | + using iterator_category = _CUDA_VSTD::random_access_iterator_tag; |
| 119 | + using difference_type = _Index; |
| 120 | + using value_type = void; |
| 121 | + using pointer = void*; |
| 122 | + using reference = void; |
| 123 | + |
| 124 | +#if _CCCL_HAS_CONCEPTS() |
| 125 | + _CCCL_EXEC_CHECK_DISABLE |
| 126 | + _CCCL_HIDE_FROM_ABI tabulate_output_iterator() |
| 127 | + requires _CUDA_VSTD::default_initializable<_Fn> |
| 128 | + = default; |
| 129 | +#else // ^^^ _CCCL_HAS_CONCEPTS() ^^^ / vvv !_CCCL_HAS_CONCEPTS() vvv |
| 130 | + _CCCL_EXEC_CHECK_DISABLE |
| 131 | + _CCCL_TEMPLATE(class _Fn2 = _Fn) |
| 132 | + _CCCL_REQUIRES(_CUDA_VSTD::default_initializable<_Fn2>) |
| 133 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator() noexcept( |
| 134 | + _CUDA_VSTD::is_nothrow_default_constructible_v<_Fn2>) |
| 135 | + {} |
| 136 | +#endif // ^^^ !_CCCL_HAS_CONCEPTS() ^^^ |
| 137 | + |
| 138 | + //! @brief Constructs a \p tabulate_output_iterator with a given functor \p __func and index 0 |
| 139 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator(_Fn __func) noexcept( |
| 140 | + _CUDA_VSTD::is_nothrow_move_constructible_v<_Fn>) |
| 141 | + : __func_(_CUDA_VSTD::in_place, _CUDA_VSTD::move(__func)) |
| 142 | + {} |
| 143 | + |
| 144 | + //! @brief Constructs a \p tabulate_output_iterator with a given functor \p __func and an index |
| 145 | + _CCCL_TEMPLATE(class _Integer) |
| 146 | + _CCCL_REQUIRES(_CUDA_VSTD::__integer_like<_Integer>) |
| 147 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator(_Fn __func, _Integer __index) noexcept( |
| 148 | + _CUDA_VSTD::is_nothrow_move_constructible_v<_Fn>) |
| 149 | + : __func_(_CUDA_VSTD::in_place, _CUDA_VSTD::move(__func)) |
| 150 | + , __index_(static_cast<_Index>(__index)) |
| 151 | + {} |
| 152 | + |
| 153 | + //! @brief Returns the stored index |
| 154 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr difference_type index() const noexcept |
| 155 | + { |
| 156 | + return __index_; |
| 157 | + } |
| 158 | + |
| 159 | + //! @brief Dereferences the \c tabulate_output_iterator returning a proxy that applies the stored function and index |
| 160 | + //! on assignment |
| 161 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator*() const noexcept |
| 162 | + { |
| 163 | + return __tabulate_proxy<_Fn, _Index>{const_cast<_Fn&>(*__func_), __index_}; |
| 164 | + } |
| 165 | + |
| 166 | + //! @brief Dereferences the \c tabulate_output_iterator returning a proxy that applies the stored function and index |
| 167 | + //! on assignment |
| 168 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator*() noexcept |
| 169 | + { |
| 170 | + return __tabulate_proxy<_Fn, _Index>{*__func_, __index_}; |
| 171 | + } |
| 172 | + |
| 173 | + //! @brief Subscripts the \c tabulate_output_iterator returning a proxy that applies the stored function and advanced |
| 174 | + //! index on assignment |
| 175 | + //! @param __n The additional offset to advance the stored index |
| 176 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator[](difference_type __n) const noexcept |
| 177 | + { |
| 178 | + return __tabulate_proxy<_Fn, _Index>{const_cast<_Fn&>(*__func_), __index_ + __n}; |
| 179 | + } |
| 180 | + |
| 181 | + //! @brief Subscripts the \c tabulate_output_iterator returning a proxy that applies the stored function and advanced |
| 182 | + //! index on assignment |
| 183 | + //! @param __n The additional offset to advance the stored index |
| 184 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator[](difference_type __n) noexcept |
| 185 | + { |
| 186 | + return __tabulate_proxy<_Fn, _Index>{*__func_, __index_ + __n}; |
| 187 | + } |
| 188 | + |
| 189 | + //! @brief Increments the stored index |
| 190 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator++() noexcept |
| 191 | + { |
| 192 | + ++__index_; |
| 193 | + return *this; |
| 194 | + } |
| 195 | + |
| 196 | + //! @brief Increments the stored index |
| 197 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator |
| 198 | + operator++(int) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>) |
| 199 | + { |
| 200 | + tabulate_output_iterator __tmp = *this; |
| 201 | + ++__index_; |
| 202 | + return __tmp; |
| 203 | + } |
| 204 | + |
| 205 | + //! @brief Decrements the stored index |
| 206 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator--() noexcept |
| 207 | + { |
| 208 | + --__index_; |
| 209 | + return *this; |
| 210 | + } |
| 211 | + |
| 212 | + //! @brief Decrements the stored index |
| 213 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator |
| 214 | + operator--(int) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>) |
| 215 | + { |
| 216 | + tabulate_output_iterator __tmp = *this; |
| 217 | + --__index_; |
| 218 | + return __tmp; |
| 219 | + } |
| 220 | + |
| 221 | + //! @brief Returns a copy of this \c tabulate_output_iterator advanced by \p __n |
| 222 | + //! @param __n The number of elements to advance |
| 223 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator operator+(difference_type __n) const |
| 224 | + noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>) |
| 225 | + { |
| 226 | + return tabulate_output_iterator{*__func_, __index_ + __n}; |
| 227 | + } |
| 228 | + |
| 229 | + //! @brief Returns a copy of a \c tabulate_output_iterator \p __iter advanced by \p __n |
| 230 | + //! @param __n The number of elements to advance |
| 231 | + //! @param __iter The original \c tabulate_output_iterator |
| 232 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr tabulate_output_iterator |
| 233 | + operator+(difference_type __n, |
| 234 | + const tabulate_output_iterator& __iter) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>) |
| 235 | + { |
| 236 | + return __iter + __n; |
| 237 | + } |
| 238 | + |
| 239 | + //! @brief Advances the index of this \c tabulate_output_iterator by \p __n |
| 240 | + //! @param __n The number of elements to advance |
| 241 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator+=(difference_type __n) noexcept |
| 242 | + { |
| 243 | + __index_ += __n; |
| 244 | + return *this; |
| 245 | + } |
| 246 | + |
| 247 | + //! @brief Returns a copy of this \c tabulate_output_iterator decremented by \p __n |
| 248 | + //! @param __n The number of elements to decrement |
| 249 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator operator-(difference_type __n) const |
| 250 | + noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>) |
| 251 | + { |
| 252 | + return tabulate_output_iterator{*__func_, __index_ - __n}; |
| 253 | + } |
| 254 | + |
| 255 | + //! @brief Returns the distance between two \c tabulate_output_iterator 's |
| 256 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr difference_type |
| 257 | + operator-(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 258 | + { |
| 259 | + return __rhs.__index_ - __lhs.__index_; |
| 260 | + } |
| 261 | + |
| 262 | + //! @brief Decrements the index of the \c tabulate_output_iterator by \p __n |
| 263 | + //! @param __n The number of elements to decrement |
| 264 | + _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator-=(difference_type __n) noexcept |
| 265 | + { |
| 266 | + __index_ -= __n; |
| 267 | + return *this; |
| 268 | + } |
| 269 | + |
| 270 | + //! @brief Compares two \c tabulate_output_iterator for equality by comparing their indices |
| 271 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 272 | + operator==(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 273 | + { |
| 274 | + return __lhs.__index_ == __rhs.__index_; |
| 275 | + } |
| 276 | + |
| 277 | +#if _CCCL_STD_VER <= 2017 |
| 278 | + //! @brief Compares two \c tabulate_output_iterator for inequality by comparing their indices |
| 279 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 280 | + operator!=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 281 | + { |
| 282 | + return __lhs.__index_ != __rhs.__index_; |
| 283 | + } |
| 284 | +#endif // _CCCL_STD_VER <= 2017 |
| 285 | + |
| 286 | +#if _LIBCUDACXX_HAS_SPACESHIP_OPERATOR() |
| 287 | + //! @brief Three-way-compares two \c tabulate_output_iterator by comparing their indices |
| 288 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr strong_ordering |
| 289 | + operator<=>(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 290 | + { |
| 291 | + return __lhs.__index_ <=> __rhs.__index_; |
| 292 | + } |
| 293 | +#endif // _LIBCUDACXX_HAS_SPACESHIP_OPERATOR() |
| 294 | + |
| 295 | + //! @brief Compares two \c tabulate_output_iterator for less than by comparing their indices |
| 296 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 297 | + operator<(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 298 | + { |
| 299 | + return __lhs.__index_ < __rhs.__index_; |
| 300 | + } |
| 301 | + |
| 302 | + //! @brief Compares two \c tabulate_output_iterator for less equal by comparing their indices |
| 303 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 304 | + operator<=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 305 | + { |
| 306 | + return __lhs.__index_ <= __rhs.__index_; |
| 307 | + } |
| 308 | + |
| 309 | + //! @brief Compares two \c tabulate_output_iterator for greater than by comparing their indices |
| 310 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 311 | + operator>(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 312 | + { |
| 313 | + return __lhs.__index_ > __rhs.__index_; |
| 314 | + } |
| 315 | + |
| 316 | + //! @brief Compares two \c tabulate_output_iterator for greater equal by comparing their indices |
| 317 | + [[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool |
| 318 | + operator>=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept |
| 319 | + { |
| 320 | + return __lhs.__index_ >= __rhs.__index_; |
| 321 | + } |
| 322 | +}; |
| 323 | + |
| 324 | +template <class _Fn> |
| 325 | +_CCCL_HOST_DEVICE tabulate_output_iterator(_Fn) -> tabulate_output_iterator<_Fn, _CUDA_VSTD::ptrdiff_t>; |
| 326 | + |
| 327 | +template <class _Fn, class _Index> |
| 328 | +_CCCL_HOST_DEVICE tabulate_output_iterator(_Fn, _Index) -> tabulate_output_iterator<_Fn, _Index>; |
| 329 | + |
| 330 | +//! @brief Creates a \p tabulate_output_iterator from an optional index. |
| 331 | +//! @param __index The index of the \p tabulate_output_iterator within a range. The default index is \c 0. |
| 332 | +//! @return A new \p tabulate_output_iterator with \p __index as the couner. |
| 333 | +_CCCL_TEMPLATE(class _Fn, class _Integer = _CUDA_VSTD::ptrdiff_t) |
| 334 | +_CCCL_REQUIRES(_CUDA_VSTD::__integer_like<_Integer>) |
| 335 | +[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto make_tabulate_output_iterator(_Fn __func, _Integer __index = 0) |
| 336 | +{ |
| 337 | + return tabulate_output_iterator{_CUDA_VSTD::move(__func), __index}; |
| 338 | +} |
| 339 | + |
| 340 | +_LIBCUDACXX_END_NAMESPACE_CUDA |
| 341 | + |
| 342 | +#include <cuda/std/__cccl/epilogue.h> |
| 343 | + |
| 344 | +#endif // _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H |
0 commit comments