Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Embedding] Enable construct saver graph when EV Ops are put on GPU. #313

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 201 additions & 4 deletions tensorflow/core/kernels/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ limitations under the License.

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace {
const int64 kEmbeddingVarUseDB = -214;
const int64 kInitializableEmbeddingVarUseDB = -215;
Expand Down Expand Up @@ -967,7 +965,6 @@ REGISTER_KERNELS_ALL_INDEX(float);
#undef REGISTER_KERNELS_ALL_INDEX
#undef REGISTER_KERNELS


#if GOOGLE_CUDA
#if TF_ENABLE_GPU_EV
#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
Expand Down Expand Up @@ -1335,7 +1332,11 @@ class KvResourceExportOpGPU : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("KvResourceExport") \
.Device(DEVICE_GPU) \
.TypeConstraint<ktype>("Tkeys") \
.TypeConstraint<vtype>("Tvalues"), \
.TypeConstraint<vtype>("Tvalues") \
.HostMemory("keys") \
.HostMemory("values") \
.HostMemory("versions") \
.HostMemory("freqs"), \
KvResourceExportOpGPU<GPUDevice, ktype, vtype>);
#define REGISTER_KERNELS_ALL_INDEX(type) \
REGISTER_KERNELS(int32, type) \
Expand All @@ -1346,6 +1347,202 @@ REGISTER_KERNELS_ALL_INDEX(float);
#undef REGISTER_KERNELS_ALL_INDEX
#undef REGISTER_KERNELS

template <typename TKey, typename TValue>
class KvResourceImportV2OpGPU: public OpKernel {
public:
explicit KvResourceImportV2OpGPU(OpKernelConstruction* c)
: OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
OP_REQUIRES_OK(c, c->GetAttr("counter_type", &counter_type_));
OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
OP_REQUIRES(c, shape_.dims() == 1,
errors::InvalidArgument("KvVariable dimension must be 1"));
OP_REQUIRES_OK(c, c->GetAttr("steps_to_live", &steps_to_live_));
OP_REQUIRES(c, steps_to_live_ >= 0,
errors::InvalidArgument(
"steps_to_live must >= 0, ",
std::to_string(steps_to_live_)));
OP_REQUIRES_OK(c, c->GetAttr("partition_id", &partition_id_));
OP_REQUIRES(c, partition_id_ >= 0,
errors::InvalidArgument(
"partition_id must >= 0, ",
std::to_string(partition_id_)));
OP_REQUIRES_OK(c, c->GetAttr("partition_num", &partition_num_));
OP_REQUIRES(c, partition_num_ >= 1,
errors::InvalidArgument(
"partition_num must >= 1, ",
std::to_string(partition_num_)));
//OP_REQUIRES_OK(c, c->GetAttr("restore_versions", &restore_versions_));
OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
OP_REQUIRES_OK(c, c->GetAttr("ht_partition_num", &ht_partition_num_));
OP_REQUIRES_OK(c, c->GetAttr("emb_index", &emb_index_));
OP_REQUIRES_OK(c, c->GetAttr("slot_index", &slot_index_));
OP_REQUIRES_OK(c, c->GetAttr("filter_freq", &filter_freq_));
OP_REQUIRES_OK(c, c->GetAttr("block_num", &block_num_));
OP_REQUIRES_OK(c, c->GetAttr("max_element_size", &max_element_size_));
OP_REQUIRES_OK(c, c->GetAttr("false_positive_probability",
&false_positive_probability_));
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
&l2_weight_threshold_));
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_));
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim",
&default_value_dim_));
OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_));
int64 storage_type = 0;
OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type));
storage_type_ = static_cast<embedding::StorageType>(storage_type);

OP_REQUIRES_OK(c, c->GetAttr("storage_path", &storage_path_));
OP_REQUIRES_OK(c, c->GetAttr("storage_size", &storage_size_));
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_));
}

void Compute(OpKernelContext* context) override {
const Tensor& file_name = context->input(0);
const std::string file_name_string = file_name.scalar<string>()();
const Tensor& name = context->input(4);
const std::string name_string = name.scalar<string>()();
const Tensor& default_values = context->input(3);
OP_REQUIRES(context, dtype_ == default_values.dtype(),
errors::InvalidArgument(
"Variable and ddd value dtypes don't match; respectively, ",
dtype_, " and ", default_values.dtype()));

ResourceHandle handle_self = HandleFromInput(context, 1);
ResourceHandle handle_primary = HandleFromInput(context, 2);
std::string opname = handle_self.name();
EmbeddingVarGPU<TKey, TValue>* ev = nullptr;
if (handle_self.name() == handle_primary.name() &&
handle_self.container() == handle_primary.container()) {
OP_REQUIRES_OK(
context,
LookupOrCreateResource<EmbeddingVarGPU<TKey, TValue>>(
context, handle_self, &ev,
[this, default_values, opname, context,
handle_self](EmbeddingVarGPU<TKey, TValue>** ptr) {
GPUHashTable<TKey, TValue>* ht =
new GPUHashTable<TKey, TValue>(-1,
context->get_allocator(AllocatorAttributes()));
*ptr = new EmbeddingVarGPU<TKey, TValue>(handle_self.name(),
ht, context->get_allocator(AllocatorAttributes()),
EmbeddingConfig(emb_index_ + block_num_ * slot_index_,
emb_index_,
block_num_, slot_num_,
opname + "-primary",
steps_to_live_, filter_freq_, max_freq_,
l2_weight_threshold_, layout_,
max_element_size_,
false_positive_probability_,
counter_type_, default_value_dim_));
return (*ptr)->Init(default_values, default_value_dim_);
}));
} else {
EmbeddingVarGPU<TKey, TValue>* primary_variable = nullptr;

OP_REQUIRES_OK(
context,
LookupOrCreateResource<EmbeddingVarGPU<TKey, TValue>>(
context, handle_primary, &primary_variable,
[this, default_values, opname, context,
handle_primary](EmbeddingVarGPU<TKey, TValue>** ptr) {
int64 primary_slot_index(0), primary_emb_index(0);
GPUHashTable<TKey, TValue>* ht =
new GPUHashTable<TKey, TValue>(-1,
context->get_allocator(AllocatorAttributes()));
*ptr = new EmbeddingVarGPU<TKey, TValue>(handle_primary.name(),
ht, context->get_allocator(AllocatorAttributes()),
EmbeddingConfig(
primary_emb_index + block_num_ * primary_slot_index,
primary_emb_index,
block_num_, slot_num_, opname + "-primary",
steps_to_live_, filter_freq_, max_freq_,
l2_weight_threshold_, layout_,
max_element_size_,
false_positive_probability_,
counter_type_));
return (*ptr)->Init(default_values, default_value_dim_);
}));


OP_REQUIRES_OK(
context,
LookupOrCreateResource<EmbeddingVarGPU<TKey, TValue>>(
context, handle_self, &ev,
[this, default_values, opname, primary_variable, context,
handle_self](EmbeddingVarGPU<TKey, TValue>** ptr) {
*ptr = new EmbeddingVarGPU<TKey, TValue>(handle_self.name(),
primary_variable->kv(),
context->get_allocator(AllocatorAttributes()),
EmbeddingConfig(emb_index_ + block_num_ * slot_index_,
emb_index_,
block_num_, slot_num_, opname,
steps_to_live_, 0,
max_freq_, l2_weight_threshold_,
layout_, 0, -1.0, counter_type_, default_value_dim_));
return (*ptr)->Init(default_values, default_value_dim_);
}));
core::ScopedUnref unref_me(primary_variable);
}
core::ScopedUnref unref_me(ev);

BundleReader reader(Env::Default(), file_name_string);
auto s = reader.status();
if (!s.ok()) {
LOG(FATAL) << "Restore EV failure, create BundleReader error:"
<< s.ToString();
}

EVRestoreDynamicallyGPU(
ev, name_string, partition_id_, partition_num_, context, &reader,
"-partition_offset", "-keys", "-values", "-versions", "-freqs");
ev->SetInitialized();
}

private:
int64 partition_id_;
int64 partition_num_;
DataType dtype_;
DataType counter_type_;
int64 max_element_size_;
float false_positive_probability_;
TensorShape shape_;
int64 steps_to_live_;
bool restore_versions_;
string ht_type_;
int64 ht_partition_num_;
int64 emb_index_;
int64 slot_index_;
int64 block_num_;
int64 slot_num_;
int64 filter_freq_;
float l2_weight_threshold_;
std::string layout_;
int64 max_freq_;
embedding::StorageType storage_type_;
std::string storage_path_;
std::vector<int64> storage_size_;
int64 default_value_dim_;
bool record_freq_;
bool record_version_;
};

#define REGISTER_KERNELS(ktype, vtype) \
REGISTER_KERNEL_BUILDER(Name("KvResourceImportV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<ktype>("Tkeys") \
.TypeConstraint<vtype>("dtype"), \
KvResourceImportV2OpGPU<ktype, vtype>);
#define REGISTER_KERNELS_ALL_INDEX(type) \
REGISTER_KERNELS(int32, type) \
REGISTER_KERNELS(int64, type)
TF_CALL_float(REGISTER_KERNELS_ALL_INDEX);
TF_CALL_double(REGISTER_KERNELS_ALL_INDEX);
//TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS_ALL_INDEX);
#undef REGISTER_KERNELS_ALL_INDEX
#undef REGISTER_KERNELS

#endif // TF_ENABLE_GPU_EV
#endif // GOOGLE_CUDA

Expand Down
Loading