
Commit

Internal change only
PiperOrigin-RevId: 689985452
SiqiaoWu1993 authored and tensorflower-gardener committed Oct 26, 2024
1 parent a10e9b3 commit a74969e
Showing 5 changed files with 26 additions and 11 deletions.
18 changes: 9 additions & 9 deletions tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc
@@ -80,11 +80,15 @@ uint64_t MlirModuleFingerprint(mlir::ModuleOp module) {
 }  // namespace
 
 absl::StatusOr<std::string> Tf2HloArg::Key() {
-  if (!topology) {
-    return absl::InternalError("topology is not set");
+  uint64_t fingerprint = tsl::Fingerprint64(platform_name);
+  if (topology) {
+    TF_ASSIGN_OR_RETURN(std::string serialized_topology, topology->Serialize());
+    fingerprint = tsl::Fingerprint64(serialized_topology);
   }
-  TF_ASSIGN_OR_RETURN(std::string serialized_topology, topology->Serialize());
-  uint64_t fingerprint = tsl::Fingerprint64(serialized_topology);
+  if (platform_name != xla::CudaName() && !topology) {
+    return absl::FailedPreconditionError(
+        "Topology is required for non-GPU compilation.");
+  }
   fingerprint =
       tsl::FingerprintCat64(fingerprint, MlirModuleFingerprint(module));
   for (const auto& dtype_and_shape : input_dtypes_and_shapes) {
@@ -205,14 +209,10 @@ absl::StatusOr<Tf2HloResult> CompileTfToHlo(const Tf2HloArg& arg) {
     tensorflow::DumpMlirOpToFile("ifrt_before_bridge_phase2", arg.module);
   }
 
-  if (!arg.topology) {
-    return absl::InternalError("topology is not set in Tf2HloArg");
-  }
-
   // Device_type is a string of
   // tensorflow/compiler/mlir/tf2xla/api/v2/device_type.proto:DeviceType
   std::string device_type = "XLA_TPU_JIT";
-  if (arg.topology->platform_name() == xla::CudaName()) {
+  if (arg.platform_name == xla::CudaName()) {
     device_type = "XLA_GPU_JIT";
   }
   VLOG(1) << "device_type: " << device_type;
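Note (not part of this commit): a standalone sketch of the new cache-key seeding performed by Tf2HloArg::Key() above. The helper name KeySeed, the is_gpu flag (standing in for platform_name == xla::CudaName()), and the tsl include path are assumptions made for illustration; the fingerprint helpers are the same tsl ones tf2hlo.cc already uses.

// Illustrative sketch only; KeySeed and is_gpu are hypothetical names, and the
// tsl include path is assumed.
#include <cstdint>
#include <optional>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tsl/platform/fingerprint.h"

absl::StatusOr<uint64_t> KeySeed(
    absl::string_view platform_name, bool is_gpu,
    const std::optional<std::string>& serialized_topology) {
  // The platform name always seeds the fingerprint ...
  uint64_t fingerprint = tsl::Fingerprint64(platform_name);
  // ... and a serialized topology, when present, replaces that seed.
  if (serialized_topology.has_value()) {
    fingerprint = tsl::Fingerprint64(*serialized_topology);
  }
  // Only GPU compilation may omit the topology; every other platform still
  // requires one.
  if (!is_gpu && !serialized_topology.has_value()) {
    return absl::FailedPreconditionError(
        "Topology is required for non-GPU compilation.");
  }
  return fingerprint;
}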
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h
@@ -47,6 +47,7 @@ struct Tf2HloArg {
   tensorflow::tpu::TPUCompileMetadataProto compile_metadata;
   tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn;
   std::shared_ptr<xla::ifrt::Topology> topology;
+  absl::string_view platform_name;
 
   absl::StatusOr<std::string> Key();
 };
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
@@ -127,6 +127,7 @@ TEST(Tf2HloTest, Empty) {
       .compile_metadata = compile_metadata,
       .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(cpu_topology_ptr),
+      .platform_name = xla::CpuName(),
   };
   auto result = CompileTfToHlo(arg);
 
@@ -180,6 +181,7 @@ TEST(Tf2HloTest, Tuple) {
       .compile_metadata = compile_metadata,
       .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(cpu_topology_ptr),
+      .platform_name = xla::CpuName(),
   };
 
   auto result = CompileTfToHlo(arg);
@@ -233,6 +235,7 @@ TEST(Tf2HloTest, Spmd) {
       .compile_metadata = compile_metadata,
       .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(cpu_topology_ptr),
+      .platform_name = xla::CpuName(),
   };
 
   auto result = CompileTfToHlo(arg);
@@ -324,6 +327,7 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) {
       .compile_metadata = compile_metadata,
       .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(cpu_topology_ptr),
+      .platform_name = xla::CpuName(),
   };
 
   auto result = CompileTfToHlo(arg);
@@ -440,6 +444,7 @@ TEST(Tf2HloTest, XlaCallHostCallback) {
       .compile_metadata = compile_metadata,
       .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(cpu_topology_ptr),
+      .platform_name = xla::CpuName(),
   };
 
   auto result = CompileTfToHlo(arg);
@@ -497,6 +502,7 @@ TEST(Tf2HloTest, GpuCompile) {
       .topology = std::make_shared<xla::ifrt::PjRtTopology>(
           std::make_shared<xla::StreamExecutorGpuTopologyDescription>(
              xla::CudaId(), xla::CudaName(), /*gpu_topology=*/nullptr)),
+      .platform_name = xla::CudaName(),
   };
 
   auto result = CompileTfToHlo(arg);
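Note (not part of this commit): the case the tests above do not exercise, but which this change permits, is a CUDA arg with no topology at all. A hypothetical construction is sketched below; only the fields visible in this diff are shown, and the remaining Tf2HloArg fields (module, input dtypes/shapes, etc.) would still be populated as in the tests above.

  // Hypothetical, not one of the tests in this file: a GPU arg whose topology
  // stays null. Key() now succeeds for it, keyed on platform_name, whereas the
  // same arg with a CPU platform name would fail with FailedPrecondition.
  Tf2HloArg gpu_arg{
      .compile_metadata = compile_metadata,
      .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(),
      // .topology intentionally left unset.
      .platform_name = xla::CudaName(),
  };
  absl::StatusOr<std::string> key = gpu_arg.Key();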
1 change: 1 addition & 0 deletions tensorflow/core/tfrt/ifrt/BUILD
@@ -164,6 +164,7 @@ cc_library(
         "@local_xla//xla/hlo/ir:hlo",
         "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
         "@local_xla//xla/pjrt:host_callback",
+        "@local_xla//xla/pjrt:pjrt_compiler",
         "@local_xla//xla/pjrt:pjrt_executable",
         "@local_xla//xla/python/ifrt",
         "@local_xla//xla/python/ifrt/hlo:hlo_program",
11 changes: 9 additions & 2 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
@@ -53,6 +53,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
 #include "xla/pjrt/host_callback.h"
+#include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/client.h"
@@ -423,9 +424,15 @@ IfrtServingExecutable::CreateExecutableSynchronously(
       .entry_function_name = signature_name(),
       .compile_metadata = compile_metadata,
       .shape_representation_fn = shape_representation_fn_,
+      .platform_name = ifrt_client_->platform_name(),
   };
-  TF_ASSIGN_OR_RETURN(tf2hlo_arg.topology, ifrt_client_->GetTopologyForDevices(
-                                               assigned_device_list_));
+
+  if (tf2hlo_arg.platform_name != xla::CudaName()) {
+    TF_ASSIGN_OR_RETURN(
+        tf2hlo_arg.topology,
+        ifrt_client_->GetTopologyForDevices(assigned_device_list_));
+  }
+
   TF_ASSIGN_OR_RETURN(Tf2HloResult tf2hlo_result,
                       persistent_compilation_cache_->LookupTf2HloResultOrCreate(
                           tf2hlo_arg, assigned_device_list_));
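Note (not part of this commit): the guard above can be read as a small helper. The sketch below is illustrative only; PopulatePlatformAndTopology is a hypothetical name, DeviceListT stands in for the type of assigned_device_list_ (which this diff does not show), and only calls that appear in the diff (platform_name(), GetTopologyForDevices(), xla::CudaName()) are used.

// Hypothetical restatement of the new control flow, not code from this commit.
template <typename DeviceListT>
absl::Status PopulatePlatformAndTopology(xla::ifrt::Client& client,
                                         const DeviceListT& devices,
                                         Tf2HloArg& arg) {
  arg.platform_name = client.platform_name();
  // CUDA callers may now compile without an explicit topology; every other
  // platform still resolves one from the IFRT client.
  if (arg.platform_name != xla::CudaName()) {
    TF_ASSIGN_OR_RETURN(arg.topology, client.GetTopologyForDevices(devices));
  }
  return absl::OkStatus();
}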
