Skip to content

Commit 18373bd

Browse files
committed
Port thrust::tabulate_output_iterator to cuda
1 parent b49b50b commit 18373bd

19 files changed

+1093
-10
lines changed
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#ifndef _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H
12+
#define _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H
13+
14+
#include <cuda/std/detail/__config>
15+
16+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
17+
# pragma GCC system_header
18+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
19+
# pragma clang system_header
20+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
21+
# pragma system_header
22+
#endif // no system header
23+
24+
#include <cuda/std/__functional/invoke.h>
25+
#include <cuda/std/__iterator/concepts.h>
26+
#include <cuda/std/__iterator/iterator_traits.h>
27+
#include <cuda/std/__ranges/movable_box.h>
28+
#include <cuda/std/__type_traits/conditional.h>
29+
#include <cuda/std/__type_traits/is_nothrow_copy_constructible.h>
30+
#include <cuda/std/__type_traits/is_nothrow_default_constructible.h>
31+
#include <cuda/std/__type_traits/is_nothrow_move_constructible.h>
32+
#include <cuda/std/__type_traits/is_same.h>
33+
#include <cuda/std/__type_traits/remove_cvref.h>
34+
#include <cuda/std/__utility/forward.h>
35+
#include <cuda/std/__utility/move.h>
36+
#include <cuda/std/cstdint>
37+
38+
#include <cuda/std/__cccl/prologue.h>
39+
40+
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
41+
42+
template <class _Fn, class _Index = _CUDA_VSTD::ptrdiff_t>
43+
class tabulate_output_iterator;
44+
45+
template <class _Fn, class _Index>
46+
class __tabulate_proxy
47+
{
48+
private:
49+
template <class, class>
50+
friend class tabulate_output_iterator;
51+
52+
_Fn& __func_;
53+
_Index __index_;
54+
55+
public:
56+
_LIBCUDACXX_HIDE_FROM_ABI constexpr explicit __tabulate_proxy(_Fn& __func, _Index __index) noexcept
57+
: __func_(__func)
58+
, __index_(__index)
59+
{}
60+
61+
_CCCL_TEMPLATE(class _Arg)
62+
_CCCL_REQUIRES(_CUDA_VSTD::is_invocable_v<const _Fn&, _Index, _Arg> _CCCL_AND(
63+
!_CUDA_VSTD::is_same_v<_CUDA_VSTD::remove_cvref_t<_Arg>, __tabulate_proxy>))
64+
_LIBCUDACXX_HIDE_FROM_ABI constexpr const __tabulate_proxy& operator=(_Arg&& __arg) const
65+
noexcept(_CUDA_VSTD::is_nothrow_invocable_v<const _Fn&, _Index, _Arg>)
66+
{
67+
_CUDA_VSTD::invoke(__func_, __index_, _CUDA_VSTD::forward<_Arg>(__arg));
68+
return *this;
69+
}
70+
71+
_CCCL_TEMPLATE(class _Arg)
72+
_CCCL_REQUIRES(_CUDA_VSTD::is_invocable_v<_Fn&, _Index, _Arg> _CCCL_AND(
73+
!_CUDA_VSTD::is_same_v<_CUDA_VSTD::remove_cvref_t<_Arg>, __tabulate_proxy>))
74+
_LIBCUDACXX_HIDE_FROM_ABI constexpr const __tabulate_proxy&
75+
operator=(_Arg&& __arg) noexcept(_CUDA_VSTD::is_nothrow_invocable_v<_Fn&, _Index, _Arg>)
76+
{
77+
_CUDA_VSTD::invoke(__func_, __index_, _CUDA_VSTD::forward<_Arg>(__arg));
78+
return *this;
79+
}
80+
};
81+
82+
//! @p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a
83+
//! dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced
84+
//! iterator and the assigned value.
85+
//!
86+
//! The following code snippet demonstrates how to create a \p tabulate_output_iterator which prints the index and the
87+
//! assigned value.
88+
//!
89+
//! @code
90+
//! #include <cuda/iterator>
91+
//!
92+
//! struct print_op
93+
//! {
94+
//! __host__ __device__ void operator()(int index, float value) const
95+
//! {
96+
//! printf("%d: %f\n", index, value);
97+
//! }
98+
//! };
99+
//!
100+
//! int main()
101+
//! {
102+
//! auto tabulate_it = cuda::make_tabulate_output_iterator(print_op{});
103+
//!
104+
//! tabulate_it[0] = 1.0f; // prints: 0: 1.0
105+
//! tabulate_it[1] = 3.0f; // prints: 1: 3.0
106+
//! tabulate_it[9] = 5.0f; // prints: 9: 5.0
107+
//! }
108+
//! @endcode
109+
template <class _Fn, class _Index>
110+
class tabulate_output_iterator
111+
{
112+
private:
113+
_CCCL_NO_UNIQUE_ADDRESS _CUDA_VRANGES::__movable_box<_Fn> __func_;
114+
_Index __index_ = 0;
115+
116+
public:
117+
using iterator_concept = _CUDA_VSTD::random_access_iterator_tag;
118+
using iterator_category = _CUDA_VSTD::random_access_iterator_tag;
119+
using difference_type = _Index;
120+
using value_type = void;
121+
using pointer = void*;
122+
using reference = void;
123+
124+
#if _CCCL_HAS_CONCEPTS()
125+
_CCCL_EXEC_CHECK_DISABLE
126+
_CCCL_HIDE_FROM_ABI tabulate_output_iterator()
127+
requires _CUDA_VSTD::default_initializable<_Fn>
128+
= default;
129+
#else // ^^^ _CCCL_HAS_CONCEPTS() ^^^ / vvv !_CCCL_HAS_CONCEPTS() vvv
130+
_CCCL_EXEC_CHECK_DISABLE
131+
_CCCL_TEMPLATE(class _Fn2 = _Fn)
132+
_CCCL_REQUIRES(_CUDA_VSTD::default_initializable<_Fn2>)
133+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator() noexcept(
134+
_CUDA_VSTD::is_nothrow_default_constructible_v<_Fn2>)
135+
{}
136+
#endif // ^^^ !_CCCL_HAS_CONCEPTS() ^^^
137+
138+
//! @brief Constructs a \p tabulate_output_iterator with a given functor \p __func and index 0
139+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator(_Fn __func) noexcept(
140+
_CUDA_VSTD::is_nothrow_move_constructible_v<_Fn>)
141+
: __func_(_CUDA_VSTD::in_place, _CUDA_VSTD::move(__func))
142+
{}
143+
144+
//! @brief Constructs a \p tabulate_output_iterator with a given functor \p __func and an index
145+
_CCCL_TEMPLATE(class _Integer)
146+
_CCCL_REQUIRES(_CUDA_VSTD::__integer_like<_Integer>)
147+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator(_Fn __func, _Integer __index) noexcept(
148+
_CUDA_VSTD::is_nothrow_move_constructible_v<_Fn>)
149+
: __func_(_CUDA_VSTD::in_place, _CUDA_VSTD::move(__func))
150+
, __index_(static_cast<_Index>(__index))
151+
{}
152+
153+
//! @brief Returns the stored index
154+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr difference_type index() const noexcept
155+
{
156+
return __index_;
157+
}
158+
159+
//! @brief Dereferences the \c tabulate_output_iterator returning a proxy that applies the stored function and index
160+
//! on assignment
161+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator*() const noexcept
162+
{
163+
return __tabulate_proxy<_Fn, _Index>{const_cast<_Fn&>(*__func_), __index_};
164+
}
165+
166+
//! @brief Dereferences the \c tabulate_output_iterator returning a proxy that applies the stored function and index
167+
//! on assignment
168+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator*() noexcept
169+
{
170+
return __tabulate_proxy<_Fn, _Index>{*__func_, __index_};
171+
}
172+
173+
//! @brief Subscripts the \c tabulate_output_iterator returning a proxy that applies the stored function and advanced
174+
//! index on assignment
175+
//! @param __n The additional offset to advance the stored index
176+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator[](difference_type __n) const noexcept
177+
{
178+
return __tabulate_proxy<_Fn, _Index>{const_cast<_Fn&>(*__func_), __index_ + __n};
179+
}
180+
181+
//! @brief Subscripts the \c tabulate_output_iterator returning a proxy that applies the stored function and advanced
182+
//! index on assignment
183+
//! @param __n The additional offset to advance the stored index
184+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator[](difference_type __n) noexcept
185+
{
186+
return __tabulate_proxy<_Fn, _Index>{*__func_, __index_ + __n};
187+
}
188+
189+
//! @brief Increments the stored index
190+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator++() noexcept
191+
{
192+
++__index_;
193+
return *this;
194+
}
195+
196+
//! @brief Increments the stored index
197+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator
198+
operator++(int) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>)
199+
{
200+
tabulate_output_iterator __tmp = *this;
201+
++__index_;
202+
return __tmp;
203+
}
204+
205+
//! @brief Decrements the stored index
206+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator--() noexcept
207+
{
208+
--__index_;
209+
return *this;
210+
}
211+
212+
//! @brief Decrements the stored index
213+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator
214+
operator--(int) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>)
215+
{
216+
tabulate_output_iterator __tmp = *this;
217+
--__index_;
218+
return __tmp;
219+
}
220+
221+
//! @brief Returns a copy of this \c tabulate_output_iterator advanced by \p __n
222+
//! @param __n The number of elements to advance
223+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator operator+(difference_type __n) const
224+
noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>)
225+
{
226+
return tabulate_output_iterator{*__func_, __index_ + __n};
227+
}
228+
229+
//! @brief Returns a copy of a \c tabulate_output_iterator \p __iter advanced by \p __n
230+
//! @param __n The number of elements to advance
231+
//! @param __iter The original \c tabulate_output_iterator
232+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr tabulate_output_iterator
233+
operator+(difference_type __n,
234+
const tabulate_output_iterator& __iter) noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>)
235+
{
236+
return __iter + __n;
237+
}
238+
239+
//! @brief Advances the index of this \c tabulate_output_iterator by \p __n
240+
//! @param __n The number of elements to advance
241+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator+=(difference_type __n) noexcept
242+
{
243+
__index_ += __n;
244+
return *this;
245+
}
246+
247+
//! @brief Returns a copy of this \c tabulate_output_iterator decremented by \p __n
248+
//! @param __n The number of elements to decrement
249+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator operator-(difference_type __n) const
250+
noexcept(_CUDA_VSTD::is_nothrow_copy_constructible_v<_Fn>)
251+
{
252+
return tabulate_output_iterator{*__func_, __index_ - __n};
253+
}
254+
255+
//! @brief Returns the distance between two \c tabulate_output_iterator 's
256+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr difference_type
257+
operator-(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
258+
{
259+
return __rhs.__index_ - __lhs.__index_;
260+
}
261+
262+
//! @brief Decrements the index of the \c tabulate_output_iterator by \p __n
263+
//! @param __n The number of elements to decrement
264+
_LIBCUDACXX_HIDE_FROM_ABI constexpr tabulate_output_iterator& operator-=(difference_type __n) noexcept
265+
{
266+
__index_ -= __n;
267+
return *this;
268+
}
269+
270+
//! @brief Compares two \c tabulate_output_iterator for equality by comparing their indices
271+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
272+
operator==(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
273+
{
274+
return __lhs.__index_ == __rhs.__index_;
275+
}
276+
277+
#if _CCCL_STD_VER <= 2017
278+
//! @brief Compares two \c tabulate_output_iterator for inequality by comparing their indices
279+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
280+
operator!=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
281+
{
282+
return __lhs.__index_ != __rhs.__index_;
283+
}
284+
#endif // _CCCL_STD_VER <= 2017
285+
286+
#if _LIBCUDACXX_HAS_SPACESHIP_OPERATOR()
287+
//! @brief Three-way-compares two \c tabulate_output_iterator by comparing their indices
288+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr strong_ordering
289+
operator<=>(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
290+
{
291+
return __lhs.__index_ <=> __rhs.__index_;
292+
}
293+
#endif // _LIBCUDACXX_HAS_SPACESHIP_OPERATOR()
294+
295+
//! @brief Compares two \c tabulate_output_iterator for less than by comparing their indices
296+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
297+
operator<(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
298+
{
299+
return __lhs.__index_ < __rhs.__index_;
300+
}
301+
302+
//! @brief Compares two \c tabulate_output_iterator for less equal by comparing their indices
303+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
304+
operator<=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
305+
{
306+
return __lhs.__index_ <= __rhs.__index_;
307+
}
308+
309+
//! @brief Compares two \c tabulate_output_iterator for greater than by comparing their indices
310+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
311+
operator>(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
312+
{
313+
return __lhs.__index_ > __rhs.__index_;
314+
}
315+
316+
//! @brief Compares two \c tabulate_output_iterator for greater equal by comparing their indices
317+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI friend constexpr bool
318+
operator>=(const tabulate_output_iterator& __lhs, const tabulate_output_iterator& __rhs) noexcept
319+
{
320+
return __lhs.__index_ >= __rhs.__index_;
321+
}
322+
};
323+
324+
template <class _Fn>
325+
_CCCL_HOST_DEVICE tabulate_output_iterator(_Fn) -> tabulate_output_iterator<_Fn, _CUDA_VSTD::ptrdiff_t>;
326+
327+
template <class _Fn, class _Index>
328+
_CCCL_HOST_DEVICE tabulate_output_iterator(_Fn, _Index) -> tabulate_output_iterator<_Fn, _Index>;
329+
330+
//! @brief Creates a \p tabulate_output_iterator from a functor and an optional index.
331+
//! @param __index The index of the \p tabulate_output_iterator within a range. The default index is \c 0.
332+
//! @return A new \p tabulate_output_iterator with \p __index as the counter.
333+
_CCCL_TEMPLATE(class _Fn, class _Integer = _CUDA_VSTD::ptrdiff_t)
334+
_CCCL_REQUIRES(_CUDA_VSTD::__integer_like<_Integer>)
335+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr auto make_tabulate_output_iterator(_Fn __func, _Integer __index = 0)
336+
{
337+
return tabulate_output_iterator{_CUDA_VSTD::move(__func), __index};
338+
}
339+
340+
_LIBCUDACXX_END_NAMESPACE_CUDA
341+
342+
#include <cuda/std/__cccl/epilogue.h>
343+
344+
#endif // _CUDA___ITERATOR_TABULATE_OUTPUT_ITERATOR_H

libcudacxx/include/cuda/iterator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <cuda/__iterator/counting_iterator.h>
2525
#include <cuda/__iterator/discard_iterator.h>
2626
#include <cuda/__iterator/strided_iterator.h>
27+
#include <cuda/__iterator/tabulate_output_iterator.h>
2728
#include <cuda/__iterator/transform_iterator.h>
2829
#include <cuda/std/iterator>
2930

0 commit comments

Comments
 (0)