Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAFT wrappers around current_device_resource functions #2424

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ int main(int argc, char** argv)
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto old_mr = raft::resource::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
raft::resource::set_current_device_resource(old_mr);
return ret;
}
#endif
6 changes: 3 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class shared_raft_resources {
using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{rmm::mr::get_current_device_resource()},
try : orig_resource_{raft::resource::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
raft::resource::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
size_t free = 0;
Expand All @@ -103,7 +103,7 @@ class shared_raft_resources {
shared_raft_resources(const shared_raft_resources& res) = delete;
shared_raft_resources& operator=(const shared_raft_resources& other) = delete;

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }
~shared_raft_resources() noexcept { raft::resource::set_current_device_resource(orig_resource_); }

auto get_large_memory_resource() noexcept
{
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ int main(int argc, char** argv)
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto old_mr = raft::resource::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
raft::resource::set_current_device_resource(old_mr);
return ret;
}
#endif
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
case (AllocatorType::HostHugePage): return &mr_huge_page_;
default: return rmm::mr::get_current_device_resource();
default: return raft::resource::get_current_device_resource_ref();
}
}
};
Expand Down
11 changes: 6 additions & 5 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/interruptible.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/random/make_blobs.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -50,20 +51,20 @@ struct using_pool_memory_res {

public:
using_pool_memory_res(size_t initial_size, size_t max_size)
: orig_res_(rmm::mr::get_current_device_resource()),
: orig_res_(raft::resource::get_current_device_resource()),
pool_res_(&cuda_res_, initial_size, max_size)
{
rmm::mr::set_current_device_resource(&pool_res_);
raft::resource::set_current_device_resource(&pool_res_);
}

using_pool_memory_res()
: orig_res_(rmm::mr::get_current_device_resource()),
: orig_res_(raft::resource::get_current_device_resource()),
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
{
rmm::mr::set_current_device_resource(&pool_res_);
raft::resource::set_current_device_resource(&pool_res_);
}

~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); }
~using_pool_memory_res() { raft::resource::set_current_device_resource(orig_res_); }
};

/**
Expand Down
8 changes: 4 additions & 4 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ template <typename T, typename MapT, typename IdxT, bool Conditional = false>
struct Gather : public fixture {
Gather(const GatherParams<IdxT>& p)
: params(p),
old_mr(rmm::mr::get_current_device_resource()),
pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)),
old_mr(raft::resource::get_current_device_resource()),
pool_mr(raft::resource::get_current_device_resource_ref(), 2 * (1ULL << 30)),
matrix(this->handle),
map(this->handle),
out(this->handle),
stencil(this->handle),
matrix_h(this->handle)
{
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);
}

~Gather() { rmm::mr::set_current_device_resource(old_mr); }
~Gather() { raft::resource::set_current_device_resource(old_mr); }

void allocate_data(const ::benchmark::State& state) override
{
Expand Down
21 changes: 8 additions & 13 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/resource/device_id.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
Expand All @@ -29,13 +30,13 @@

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/sequence.h>

#include <memory>
#include <optional>

namespace raft::bench::spatial {
Expand Down Expand Up @@ -89,25 +90,19 @@ inline auto operator<<(std::ostream& os, const Scope& s) -> std::ostream&

struct device_resource {
public:
explicit device_resource(bool managed) : managed_(managed)
explicit device_resource(bool managed)
: managed_(managed ? std::make_shared<rmm::mr::managed_memory_resource>() : nullptr),
res_(managed ? managed_.get() : raft::resource::get_current_device_resource_ref())
{
if (managed_) {
res_ = new rmm::mr::managed_memory_resource();
} else {
res_ = rmm::mr::get_current_device_resource();
}
}

~device_resource()
{
if (managed_) { delete res_; }
}
~device_resource() = default;

[[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; }

private:
const bool managed_;
rmm::mr::device_memory_resource* res_;
std::shared_ptr<rmm::mr::device_memory_resource> managed_;
rmm::device_async_resource_ref res_;
};

template <typename T>
Expand Down
7 changes: 4 additions & 3 deletions cpp/bench/prims/neighbors/refine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/detail/refine.cuh>
#include <raft/neighbors/refine.cuh>
Expand Down Expand Up @@ -58,10 +59,10 @@ class RefineAnn : public fixture {
label_stream << data.p;
state.SetLabel(label_stream.str());

auto old_mr = rmm::mr::get_current_device_resource();
auto old_mr = raft::resource::get_current_device_resource();
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr(
old_mr, rmm::percent_of_free_device_memory(50));
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);

if (data.p.host_data) {
loop_on_state(state, [this]() {
Expand All @@ -84,7 +85,7 @@ class RefineAnn : public fixture {
data.p.metric);
});
}
rmm::mr::set_current_device_resource(old_mr);
raft::resource::set_current_device_resource(old_mr);
}

private:
Expand Down
9 changes: 5 additions & 4 deletions cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/random/permute.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/sample_without_replacement.cuh>
Expand Down Expand Up @@ -64,16 +65,16 @@ template <typename T>
struct sample : public fixture {
sample(const sample_inputs& p)
: params(p),
old_mr(rmm::mr::get_current_device_resource()),
pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB),
old_mr(raft::resource::get_current_device_resource()),
pool_mr(raft::resource::get_current_device_resource(), 2 * GiB),
in(make_device_vector<T, int64_t>(res, p.n_samples)),
out(make_device_vector<T, int64_t>(res, p.n_train))
{
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);
raft::random::RngState r(123456ULL);
}

~sample() { rmm::mr::set_current_device_resource(old_mr); }
~sample() { raft::resource::set_current_device_resource(old_mr); }
void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
Expand Down
4 changes: 1 addition & 3 deletions cpp/include/raft/core/device_container_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/device_ptr.h>

Expand Down Expand Up @@ -185,7 +183,7 @@ class device_uvector_policy {
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()};
rmm::device_async_resource_ref mr_{raft::resource::get_current_device_resource_ref()};
};

} // namespace raft
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/core/device_resources_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
#pragma once
#include <raft/core/device_resources.hpp>
#include <raft/core/device_setter.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <algorithm>
#include <memory>
Expand Down Expand Up @@ -169,14 +169,14 @@ struct device_resources_manager {
// resource
if (params.max_mem_pool_size.value_or(1) != 0) {
auto* upstream =
dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource());
dynamic_cast<rmm::mr::cuda_memory_resource*>(resource::get_current_device_resource());
if (upstream != nullptr) {
result =
std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
upstream,
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
params.max_mem_pool_size);
rmm::mr::set_current_device_resource(result.get());
raft::resource::set_current_device_resource(result.get());
} else {
RAFT_LOG_WARN(
"Pool allocation requested, but other memory resource has already been set and "
Expand Down
Loading
Loading