Skip to content

Commit

Permalink
messages
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszpn committed Jan 25, 2024
1 parent 3b16615 commit faf565c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
25 changes: 23 additions & 2 deletions include/dr/mhp/algorithms/sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ template <typename valT, typename Compare, typename Seg>
void splitters(Seg &lsegment, Compare &&comp,
std::vector<std::size_t> &vec_split_i,
std::vector<std::size_t> &vec_split_s) {
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

const std::size_t _comm_size = default_comm().size(); // dr-style ignore

Expand All @@ -157,12 +158,16 @@ void splitters(Seg &lsegment, Compare &&comp,
* each segment into equal parts */
if (mhp::use_sycl()) {
#ifdef SYCL_LANGUAGE_VERSION
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

for (std::size_t _i = 0; _i < rng::size(vec_lmedians) - 1; _i++) {
assert(_i * _step_m < rng::size(lsegment));
sycl_copy<valT>(&lsegment[_i * _step_m], &vec_lmedians[_i]);
}
sycl_copy<valT>(&lsegment[rng::size(lsegment) - 1],
&vec_lmedians[rng::size(vec_lmedians) - 1]);

fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
#else
assert(false);
#endif
Expand All @@ -186,8 +191,6 @@ void splitters(Seg &lsegment, Compare &&comp,

std::size_t segidx = 0, vidx = 1;

// auto begin = std::chrono::high_resolution_clock::now();

/* TODO: copy and loop below takes most of time of the whole sort procedure;
* move it to the SYCL kernel */
if (mhp::use_sycl()) {
Expand All @@ -196,6 +199,8 @@ void splitters(Seg &lsegment, Compare &&comp,
sycl_copy(rng::data(lsegment), rng::data(vec_lseg_tmp),
rng::size(lsegment));

fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

while (vidx < _comm_size && segidx < rng::size(lsegment)) {
if (comp(vec_split_v[vidx - 1], vec_lseg_tmp[segidx])) {
vec_split_i[vidx] = segidx;
Expand All @@ -205,10 +210,14 @@ void splitters(Seg &lsegment, Compare &&comp,
segidx++;
}
}
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

#else
assert(false);
#endif
} else {
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

while (vidx < _comm_size && segidx < rng::size(lsegment)) {
if (comp(vec_split_v[vidx - 1], lsegment[segidx])) {
vec_split_i[vidx] = segidx;
Expand All @@ -218,13 +227,15 @@ void splitters(Seg &lsegment, Compare &&comp,
segidx++;
}
}
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
}
assert(rng::size(lsegment) > vec_split_i[vidx - 1]);
vec_split_s[vidx - 1] = rng::size(lsegment) - vec_split_i[vidx - 1];

// auto end = std::chrono::high_resolution_clock::now();
// fmt::print("{}: splitters 3 duration {} ms\n", default_comm().rank(),
// std::chrono::duration<float>(end - begin).count() * 1000);
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
}

template <typename valT>
Expand Down Expand Up @@ -367,13 +378,17 @@ void dist_sort(R &r, Compare &&comp) {
std::vector<std::size_t> vec_recv_elems(_comm_size, 0);
std::size_t _total_elems = 0;

fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
__detail::local_sort(lsegment, comp);

/* find splitting values - limits of areas to send to other processes */
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
__detail::splitters<valT>(lsegment, comp, vec_split_i, vec_split_s);
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
default_comm().alltoall(vec_split_s, vec_rsizes, 1);

/* prepare data to send and receive */
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(),
vec_rindices.begin(), 0);
const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back();
Expand All @@ -382,18 +397,22 @@ void dist_sort(R &r, Compare &&comp) {
* data to achieve size of data equal to size of local segment */

MPI_Request req_recvelems;
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
default_comm().i_all_gather(_recv_elems, vec_recv_elems, &req_recvelems);

/* buffer for received data */
buffer<valT> vec_recvdata(_recv_elems);

/* send data not belonging and receive data belonging to local processes
*/
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
vec_rsizes, vec_rindices);

/* TODO: vec recvdata is partially sorted, implementation of merge on GPU is
* desirable */
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);

__detail::local_sort(vec_recvdata, comp);

MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE);
Expand All @@ -419,10 +438,12 @@ void dist_sort(R &r, Compare &&comp) {

/* shift data if necessary, to have exactly the number of elements equal to
* lsegment size */
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
__detail::shift_data<valT>(shift_left, shift_right, vec_recvdata, vec_left,
vec_right);

/* copy results to distributed vector's local segment */
fmt::print("{}:{}\n", default_comm().rank(), __LINE__);
__detail::copy_results<valT>(lsegment, shift_left, shift_right, vec_recvdata,
vec_left, vec_right);

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/mhp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ add_executable(

add_executable(mhp-quick-test
mhp-tests.cpp
../common/inclusive_scan.cpp
../common/sort.cpp
)
# cmake-format: on

Expand Down

0 comments on commit faf565c

Please sign in to comment.