Skip to content

Commit

Permalink
Merge branch 'fea-deeprec_may_update-aleliu' into 'integration'
Browse files Browse the repository at this point in the history
[Ready for merge] sok update

See merge request dl/hugectr/hugectr!1351
  • Loading branch information
minseokl committed May 18, 2023
2 parents ce20a63 + 0aa6b64 commit 201ad20
Show file tree
Hide file tree
Showing 17 changed files with 498 additions and 275 deletions.
6 changes: 3 additions & 3 deletions HugeCTR/core/hctr_impl/hctr_backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <cuda_runtime.h>

#include <core/core.hpp>
#include <resource_manager.hpp>
#include <resource_manager_base.hpp>
#include <unordered_map>

namespace hctr_internal {
Expand All @@ -44,15 +44,15 @@ class GPUResource final : public core::GPUResourceBase {
};

class HCTRCoreResourceManager : public core::CoreResourceManager {
std::shared_ptr<HugeCTR::ResourceManager> ext_;
std::shared_ptr<HugeCTR::ResourceManagerBase> ext_;
int local_id_;
int global_id_;
int device_id_;

std::shared_ptr<core::GPUResourceBase> gpu_resource_;

public:
HCTRCoreResourceManager(std::shared_ptr<HugeCTR::ResourceManager> ext, int local_id)
HCTRCoreResourceManager(std::shared_ptr<HugeCTR::ResourceManagerBase> ext, int local_id)
: ext_(ext),
local_id_(local_id),
global_id_(ext_->get_gpu_global_id_from_local_id(local_id)),
Expand Down
36 changes: 19 additions & 17 deletions HugeCTR/core23/registry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,27 @@
break; \
}

#define DISPATCH_INTEGRAL_FUNCTION_CORE23(DATA_TYPE, HINT, ...) \
switch (DATA_TYPE) { \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int64, int64_t, HINT, __VA_ARGS__) \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int32, int32_t, HINT, __VA_ARGS__) \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::UInt64, uint64_t, HINT, __VA_ARGS__) \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::UInt32, uint32_t, HINT, __VA_ARGS__) \
default: \
HCTR_LOG_S(ERROR, WORLD) << "DataType is" << DATA_TYPE << std::endl; \
HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall, \
"DISPATCH_INTEGRAL_FUNCTION do not such support type"); \
// Dispatches __VA_ARGS__ with HINT bound to the C++ type matching DATA_TYPE.
// Supports all integral core23 scalar types; both Int64 and LongLong map to
// int64_t (they are distinct core23::ScalarType enumerators with the same
// underlying width). Logs and throws IllegalCall for any other type.
#define DISPATCH_INTEGRAL_FUNCTION_CORE23(DATA_TYPE, HINT, ...)                            \
  switch (DATA_TYPE) {                                                                     \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int64, int64_t, HINT, __VA_ARGS__)     \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int32, int32_t, HINT, __VA_ARGS__)     \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::UInt64, uint64_t, HINT, __VA_ARGS__)   \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::UInt32, uint32_t, HINT, __VA_ARGS__)   \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::LongLong, int64_t, HINT, __VA_ARGS__)  \
    default:                                                                               \
      /* Space after "is" so the value is not fused to the word in the log. */             \
      HCTR_LOG_S(ERROR, WORLD) << "DataType is " << DATA_TYPE << std::endl;                \
      HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall,                                        \
                     "DISPATCH_INTEGRAL_FUNCTION_CORE23 does not support this type");      \
  }

#define DISPATCH_SIGNED_INTEGRAL_FUNCTION_CORE23(DATA_TYPE, HINT, ...) \
switch (DATA_TYPE) { \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int64, int64_t, HINT, __VA_ARGS__) \
CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int32, int32_t, HINT, __VA_ARGS__) \
default: \
HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall, \
"DISPATCH_SIGNED_INTEGRAL_FUNCTION do not support type"); \
// Dispatches __VA_ARGS__ with HINT bound to the signed C++ type matching
// DATA_TYPE (Int64 and LongLong both map to int64_t). Throws IllegalCall for
// unsigned or non-integral types.
#define DISPATCH_SIGNED_INTEGRAL_FUNCTION_CORE23(DATA_TYPE, HINT, ...)                       \
  switch (DATA_TYPE) {                                                                       \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int64, int64_t, HINT, __VA_ARGS__)       \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::Int32, int32_t, HINT, __VA_ARGS__)       \
    CASE_TYPE_USING_HINT_CORE23(core23::ScalarType::LongLong, int64_t, HINT, __VA_ARGS__)    \
    default:                                                                                 \
      HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall,                                          \
                     "DISPATCH_SIGNED_INTEGRAL_FUNCTION_CORE23 does not support this type"); \
  }

#define DISPATCH_UNSIGNED_INTEGRAL_FUNCTION_CORE23(DATA_TYPE, HINT, ...) \
Expand Down
1 change: 0 additions & 1 deletion HugeCTR/embedding/all2all_embedding_collection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <cub/cub.cuh>
#include <embedding/all2all_embedding_collection.hpp>
#include <embedding/data_distributor/data_distributor.hpp>
#include <embeddings/embedding_collection.hpp>
#include <utils.hpp>

namespace embedding {
Expand Down
1 change: 0 additions & 1 deletion HugeCTR/embedding/data_distributor/data_distributor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <embedding/operators/mp_index_calculation.hpp>
#include <embedding/operators/transpose_input.hpp>
#include <optional>
#include <resource_manager.hpp>
#include <unordered_map>
#include <vector>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <HugeCTR/embedding/common.hpp>
#include <HugeCTR/include/utils.cuh>
#include <HugeCTR/include/utils.hpp>
#include <embedding/data_distributor/data_distribution_op.hpp>
#include <embedding/operators/communication.hpp>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cuda_runtime.h>

#include <utils.cuh>
#include <utils.hpp>

#include "key_filtering_operators.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <core/core.hpp>
#include <embedding/common.hpp>
#include <memory>
#include <resource_manager.hpp>
#include <unordered_map>
#include <vector>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <HugeCTR/embedding/common.hpp>
#include <HugeCTR/include/utils.cuh>
#include <HugeCTR/include/utils.hpp>
#include <cub/cub.cuh>
#include <embedding/data_distributor/data_distribution_op.hpp>
#include <embedding/operators/communication.hpp>
Expand Down
1 change: 0 additions & 1 deletion HugeCTR/embedding/operators/keys_to_indices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#pragma once
#include <HugeCTR/embedding/common.hpp>
#include <HugeCTR/embedding_storage/common.hpp>
#include <resource_manager.hpp>
namespace embedding {
class KeysToIndicesConverter {
private:
Expand Down
3 changes: 0 additions & 3 deletions HugeCTR/include/resource_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@ class ResourceManager : public ResourceManagerBase {
virtual int get_master_process_id() const = 0;
virtual bool is_master_process() const = 0;
virtual const std::shared_ptr<CPUResource>& get_local_cpu() const = 0;
virtual const std::vector<int>& get_local_gpu_device_id_list() const = 0;
const virtual std::vector<std::shared_ptr<GPUResource>>& get_local_gpus() const = 0;
virtual int get_process_id_from_gpu_global_id(size_t global_gpu_id) const = 0;
virtual size_t get_gpu_local_id_from_global_id(size_t global_gpu_id) const = 0;
virtual size_t get_gpu_global_id_from_local_id(size_t local_gpu_id) const = 0;
virtual bool p2p_enabled(int src_dev, int dst_dev) const = 0;
virtual bool all_p2p_enabled() const = 0;

Expand Down
3 changes: 3 additions & 0 deletions HugeCTR/include/resource_manager_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class ResourceManagerBase {
size_t device_id) const = 0;
virtual size_t get_local_gpu_count() const = 0;
virtual size_t get_global_gpu_count() const = 0;
virtual size_t get_gpu_local_id_from_global_id(size_t global_gpu_id) const = 0;
virtual size_t get_gpu_global_id_from_local_id(size_t local_gpu_id) const = 0;
virtual const std::vector<int>& get_local_gpu_device_id_list() const = 0;
};

} // namespace HugeCTR
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,64 @@ TFAdapter<KeyType, OffsetType, DType>::TFAdapter()
CUDACHECK(cudaDeviceGetAttribute(&sm_count_, cudaDevAttrMultiProcessorCount, device));
}

// Registers a set of embedding tables given as raw float pointers (the
// "Variable" overload; a parallel overload exists for ResourceVariable) and
// uploads the per-table metadata — table pointers, dimensions, scale factors,
// and the id-space -> local-index map — into freshly allocated device buffers.
//
// vars:       one float* per table (presumably device memory — TODO confirm
//             against the caller)
// dimensions: per-table embedding dimension (assumed; verify against caller)
// scale:      per-table integer scale factor (semantics not visible here)
// stream:     CUDA stream used for the metadata uploads; cached in stream_
template <typename KeyType, typename OffsetType, typename DType>
void TFAdapter<KeyType, OffsetType, DType>::set(std::vector<float*>& vars,
std::vector<int>& dimensions,
std::vector<int>& scale, cudaStream_t stream) {
// In this overload each table's id space is simply its position in `vars`,
// so `id_space` is the identity mapping 0..n-1.
std::vector<float*> data;
std::vector<int> id_space;
for (int i = 0; i < vars.size(); ++i) {
float* input = vars[i];
data.push_back(input);
id_space.push_back(i);
}

// Fast path: if exactly the same table pointers were registered last time,
// keep the previously uploaded device buffers and return.
// NOTE(review): only the pointers are compared — a call with identical
// pointers but changed `dimensions`, `scale`, or `stream` is silently
// ignored; confirm this caching is intentional.
if (data_.size() == data.size()) {
bool skip = true;
for (int i = 0; i < data.size(); ++i) {
if (data_[i] != data[i]) {
skip = false;
break;
}
}
if (skip) {
return;
}
}

// Cache host-side copies in members; the async uploads below read from
// these members (not the caller's vectors), so they must outlive the copies.
data_ = data;
dimensions_ = dimensions;
scale_ = scale;
stream_ = stream;

// Since `id_space` is 0..n-1 here, every element is overwritten; the -1
// fill only matters for the ResourceVariable overload's sparse id spaces.
id_space_to_local_index_.resize(vars.size(), -1);
for (int i = 0; i < id_space.size(); ++i) {
id_space_to_local_index_[id_space[i]] = i;
}

// Release device buffers from any previous set() before reallocating; must
// happen before the cudaMalloc calls below so the d_* members are not leaked.
free();

CUDACHECK(cudaMalloc(&d_data_, sizeof(float*) * data_.size()));
CUDACHECK(cudaMalloc(&d_dimensions_, sizeof(int) * dimensions_.size()));
CUDACHECK(cudaMalloc(&d_id_space_to_local_index_, sizeof(int) * id_space_to_local_index_.size()));
CUDACHECK(cudaMalloc(&d_scale_, sizeof(int) * scale_.size()));

// Metadata uploads are enqueued on stream_. The sources are pageable member
// vectors, so each cudaMemcpyAsync may synchronize with respect to the host;
// that is acceptable for this infrequent registration path.
// clang-format off
CUDACHECK(cudaMemcpyAsync(d_data_, data_.data(),
sizeof(float*) * data_.size(),
cudaMemcpyHostToDevice, stream_));
CUDACHECK(cudaMemcpyAsync(d_dimensions_, dimensions_.data(),
sizeof(int) * dimensions_.size(),
cudaMemcpyHostToDevice, stream_));
CUDACHECK(cudaMemcpyAsync(d_id_space_to_local_index_, id_space_to_local_index_.data(),
sizeof(int) * id_space_to_local_index_.size(),
cudaMemcpyHostToDevice, stream_));
CUDACHECK(cudaMemcpyAsync(d_scale_, scale_.data(),
sizeof(int) * scale_.size(),
cudaMemcpyHostToDevice, stream_));
// clang-format on
}

template <typename KeyType, typename OffsetType, typename DType>
void TFAdapter<KeyType, OffsetType, DType>::set(
std::vector<tensorflow::core::RefCountPtr<tensorflow::Var>>& vars,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,18 @@

namespace sok {
namespace core23 = HugeCTR::core23;

template <typename KeyType, typename OffsetType, typename DType>
class TFAdapter : public ::embedding::ILookup {
public:
TFAdapter();
virtual ~TFAdapter();

// for Variable
void set(std::vector<float*>& vars, std::vector<int>& dimensions, std::vector<int>& scale,
cudaStream_t stream = 0);

// for ResourceVariable
void set(std::vector<tensorflow::core::RefCountPtr<tensorflow::Var>>& vars,
std::vector<tensorflow::tf_shared_lock>& locks, std::vector<int>& dimensions,
std::vector<int>& scale, cudaStream_t stream = 0);
Expand All @@ -60,7 +66,6 @@ class TFAdapter : public ::embedding::ILookup {
int* d_id_space_to_local_index_;
int* d_scale_;
cudaStream_t stream_;

void free();
};

Expand Down
Loading

0 comments on commit 201ad20

Please sign in to comment.