Skip to content

Commit

Permalink
Add GetArgName, GetVariableName and GetResultName functions to XlaCom…
Browse files Browse the repository at this point in the history
…piledCpuFunction that return an arguments' variable's or result's name at a given index.

For AoT-compiled functions, this is only available if static name information was generated (see --gen_name_to_index flag for tfcompile).

PiperOrigin-RevId: 588698517
  • Loading branch information
tensorflower-gardener committed Dec 7, 2023
1 parent cb59a3f commit 5b5eaab
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,34 @@ int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
return LookupNameIndex(name, result_names_);
}

const char* XlaCompiledCpuFunction::GetArgName(const int index) const {
assert(arg_names_ != nullptr);
if (index < 0 || index >= num_args_) {
std::cerr << "XlaCompiledCpuFunction::GetArgName: index '" << index
<< "' out of range [0, " << num_args_ << "].\n";
return nullptr;
}
return arg_names_[index];
}

const char* XlaCompiledCpuFunction::GetVariableName(int index) const {
assert(variable_names_ != nullptr);
if (index < 0 || index >= num_variables_) {
std::cerr << "XlaCompiledCpuFunction::GetVariableName: index '" << index
<< "' out of range [0, " << num_variables_ << ").\n";
return nullptr;
}
return variable_names_[index];
}

const char* XlaCompiledCpuFunction::GetResultName(int index) const {
assert(result_names_ != nullptr);
if (index < 0 || index >= num_results_) {
std::cerr << "XlaCompiledCpuFunction::GetResultName: index '" << index
<< "' out of range [0, " << num_results_ << ").\n";
return nullptr;
}
return result_names_[index];
}

} // namespace tensorflow
12 changes: 12 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,18 @@ class XlaCompiledCpuFunction {
// Recommended usage is to capture this in a variable for re-use.
int LookupResultIndex(const string& name) const;

// Returns the name of the argument at `index`.
// Returns nullptr if `HasNameIndices() == false` or `index` is out of range.
const char* GetArgName(int index) const;

// Returns the name of the variable at `index`.
// Returns nullptr if `HasNameIndices() == false` or `index` is out of range.
const char* GetVariableName(int index) const;

// Returns the name of the result at `index`.
// Returns nullptr if `HasNameIndices() == false` or `index` is out of range.
const char* GetResultName(int index) const;

// Returns the shape of the args and results. May return nullptr if the
// program shape isn't available.
const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
Expand Down
25 changes: 25 additions & 0 deletions tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,26 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
EXPECT_EQ(0, function.num_variables());
EXPECT_EQ(function.LookupVariableIndex("x"), -1);

// Expect that name and index lookups match.
for (int i = 0; i < function.num_args(); ++i) {
const char* name = function.GetArgName(i);
ASSERT_NE(name, nullptr);
const int roundtrip_i = function.LookupArgIndex(name);
EXPECT_EQ(roundtrip_i, i) << " name= " << name;
}
for (int i = 0; i < function.num_results(); ++i) {
const char* name = function.GetResultName(i);
ASSERT_NE(name, nullptr);
const int roundtrip_i = function.LookupResultIndex(name);
EXPECT_EQ(roundtrip_i, i) << " name= " << name;
}
// Expect correct handling of invalid indices.
EXPECT_EQ(function.GetArgName(-1), nullptr);
EXPECT_EQ(function.GetArgName(function.num_args()), nullptr);
EXPECT_EQ(function.GetResultName(-1), nullptr);
EXPECT_EQ(function.GetResultName(function.num_results()), nullptr);
EXPECT_EQ(function.GetVariableName(0), nullptr);

// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
Expand Down Expand Up @@ -263,6 +283,11 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) {
EXPECT_EQ(1, function.num_variables());
EXPECT_EQ(function.LookupVariableIndex("myvar"), 1);

const char* name = function.GetVariableName(0);
EXPECT_EQ(std::string(name), "myvar");
EXPECT_EQ(function.GetVariableName(1), nullptr);
EXPECT_EQ(function.GetVariableName(-1), nullptr);

// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
Expand Down

0 comments on commit 5b5eaab

Please sign in to comment.