-
Notifications
You must be signed in to change notification settings - Fork 44
Honor output range distribution in dash::transform #398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: development
Are you sure you want to change the base?
Changes from all commits
8d61420
1b431e0
b0b07b3
7fa405e
8325b0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
#include <dash/GlobRef.h> | ||
#include <dash/GlobAsyncRef.h> | ||
|
||
#include <dash/algorithm/Copy.h> | ||
#include <dash/algorithm/LocalRange.h> | ||
#include <dash/algorithm/Operation.h> | ||
#include <dash/algorithm/Accumulate.h> | ||
|
@@ -48,8 +49,33 @@ inline dart_ret_t transform_blocking_impl( | |
return result; | ||
} | ||
|
||
/** | ||
* Wrapper of the non-blocking DART accumulate operation with local completion.
* Allows re-use of \c values pointer after the call returns.
*
* Issues \c dart_accumulate on \c dest using the elements at \c values and
* the operation \c op, then calls \c dart_flush_local on \c dest before
* returning.
*
* \tparam ValueType  Element type; must map to a DART datatype other than
*                    \c DART_TYPE_UNDEFINED (enforced at compile time).
* \param  dest       Global pointer handed to \c dart_accumulate and
*                    \c dart_flush_local.
* \param  values     Pointer to the local source elements.
* \param  nvalues    Element count forwarded to \c dart_accumulate.
* \param  op         DART operation forwarded to \c dart_accumulate.
*
* \return The return code of \c dart_accumulate. The return code of
*         \c dart_flush_local is not inspected.
*/ | ||
template< typename ValueType > | ||
dart_ret_t transform_local_blocking_impl( | ||
dart_gptr_t dest, | ||
ValueType * values, | ||
size_t nvalues, | ||
dart_operation_t op) | ||
{ | ||
// Reject element types without a DART datatype mapping at compile time. | ||
static_assert(dash::dart_datatype<ValueType>::value != DART_TYPE_UNDEFINED, | ||
"Cannot accumulate unknown type!"); | ||
|
||
// Start the accumulate; its return code is what the caller receives. | ||
dart_ret_t result = dart_accumulate( | ||
dest, | ||
reinterpret_cast<void *>(values), | ||
nvalues, | ||
dash::dart_datatype<ValueType>::value, | ||
op); | ||
// NOTE(review): the flush is issued even if dart_accumulate reported an | ||
// error, and the flush's own return value is discarded. | ||
dart_flush_local(dest); | ||
return result; | ||
} | ||
|
||
/** | ||
* Wrapper of the non-blocking DART accumulate operation. | ||
* The pointer \c values should not be re-used before the operation completed. | ||
*/ | ||
template< typename ValueType > | ||
dart_ret_t transform_impl( | ||
|
@@ -67,7 +93,6 @@ dart_ret_t transform_impl( | |
nvalues, | ||
dash::dart_datatype<ValueType>::value, | ||
op); | ||
dart_flush_local(dest); | ||
return result; | ||
} | ||
|
||
|
@@ -271,40 +296,65 @@ GlobOutputIt transform( | |
BinaryOperation binary_op) | ||
{ | ||
DASH_LOG_DEBUG("dash::transform(af, al, bf, outf, binop)"); | ||
auto &pattern = out_first.pattern(); | ||
// Output range different from rhs input range is not supported yet | ||
auto in_first = in_a_first; | ||
auto in_last = in_a_last; | ||
std::vector<ValueType> in_range; | ||
ValueType* in_first = &(*in_a_first); | ||
ValueType* in_last = &(*in_a_last); | ||
// Number of elements in local range: | ||
size_t num_local_elements = std::distance(in_first, in_last); | ||
auto out_last = out_first + num_local_elements; | ||
if (out_last.gpos() > pattern.size()) { | ||
DASH_THROW(dash::exception::OutOfRange, | ||
"Too many input elements in dash::transform"); | ||
} | ||
if (in_b_first == out_first) { | ||
// Output range is rhs input range: C += A | ||
// Input is (in_a_first, in_a_last). | ||
} else { | ||
// Output range different from rhs input range: C = A+B | ||
// Input is (in_a_first, in_a_last) + (in_b_first, in_b_last): | ||
std::transform( | ||
in_a_first, in_a_last, | ||
dash::copy( | ||
in_b_first, | ||
std::back_inserter(in_range), | ||
binary_op); | ||
in_first = in_range.data(); | ||
in_last = in_first + in_range.size(); | ||
in_b_first + std::distance(in_a_first, in_a_last), | ||
out_first); | ||
} | ||
|
||
dash::util::Trace trace("transform"); | ||
|
||
// Resolve local range from global range: | ||
// Number of elements in local range: | ||
size_t num_local_elements = std::distance(in_first, in_last); | ||
// Global iterator to dart_gptr_t: | ||
dart_gptr_t dest_gptr = out_first.dart_gptr(); | ||
// Send accumulate message: | ||
trace.enter_state("transform_blocking"); | ||
dash::internal::transform_blocking_impl( | ||
auto &team = pattern.team(); | ||
size_t towrite = num_local_elements; | ||
auto out_it = out_first; | ||
auto in_it = in_first; | ||
while (towrite > 0) { | ||
auto lpos = out_it.lpos(); | ||
size_t lsize = pattern.local_size(lpos.unit); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is what I already mentioned multiple times. This approach only works for a single continuous index range. Hence it would be better to use the new range based views interface by @fuchsto . There should be something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fmoessbauer @bertwesarg The term |
||
size_t num_values = std::min(lsize - lpos.index, towrite); | ||
dart_gptr_t dest_gptr = out_it.dart_gptr(); | ||
// use non-blocking transform and wait for all at the end | ||
dash::internal::transform_impl( | ||
dest_gptr, | ||
in_first, | ||
num_local_elements, | ||
in_it, | ||
num_values, | ||
binary_op.dart_operation()); | ||
trace.exit_state("transform_blocking"); | ||
out_it += num_values; | ||
in_it += num_values; | ||
towrite -= num_values; | ||
} | ||
|
||
// out_first.team().barrier(); | ||
dart_flush_all(out_first.dart_gptr()); | ||
|
||
|
||
// trace.enter_state("transform_blocking"); | ||
// dash::internal::transform_blocking_impl( | ||
// dest_gptr, | ||
// in_first, | ||
// num_local_elements, | ||
// binary_op.dart_operation()); | ||
// trace.exit_state("transform_blocking"); | ||
// The position past the last element transformed in global element space | ||
// cannot be resolved from the size of the local range if the local range | ||
// spans over more than one block. Otherwise, the difference of two global | ||
|
@@ -320,7 +370,7 @@ GlobOutputIt transform( | |
// For ranges over block borders, we would have to resolve the global | ||
// position past the last element transformed from the iterator's pattern | ||
// (see dash::PatternIterator). | ||
return out_first + num_local_elements; | ||
return out_it; | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
|
||
#include <dash/algorithm/Transform.h> | ||
#include <dash/algorithm/Generate.h> | ||
#include <dash/algorithm/Fill.h> | ||
#include <dash/algorithm/ForEach.h> | ||
|
||
#include <dash/Array.h> | ||
#include <dash/Matrix.h> | ||
|
@@ -221,3 +223,41 @@ TEST_F(TransformTest, MatrixGlobalPlusGlobalBlocking) | |
EXPECT_EQ_U(first_l_block_a_begin, | ||
first_l_block_a_offsets); | ||
} | ||
|
||
|
||
TEST_F(TransformTest, LocalIteratorInput) | ||
{ | ||
using value_t = int; | ||
std::vector<value_t> local_v(100); | ||
size_t idx = 0; | ||
std::fill(local_v.begin(), local_v.end(), (value_t)dash::myid()); | ||
for (auto& elem : local_v) { | ||
elem = dash::myid() * 1000 + idx; | ||
idx++; | ||
} | ||
dash::Array<value_t> global_v(local_v.size() + 1); | ||
dash::fill(global_v.begin(), global_v.end(), 0.0); | ||
global_v.barrier(); | ||
// start from the second element | ||
auto it = dash::transform<value_t>( | ||
local_v.begin(), | ||
local_v.end(), | ||
global_v.begin() + 1, | ||
global_v.begin() + 1, | ||
dash::max<value_t>() | ||
); | ||
|
||
global_v.barrier(); | ||
|
||
ASSERT_EQ_U(it, global_v.end()); | ||
|
||
// size_t idx = 0; | ||
|
||
dash::for_each_with_index(global_v.begin() + 1, global_v.end(), | ||
[](value_t val, size_t idx){ | ||
ASSERT_EQ_U(val, (dash::size() - 1) * 1000 + (idx - 1)); | ||
++idx; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove |
||
}); | ||
|
||
global_v.barrier(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused?