From 5ab99bcd39a484ca04d840f0d3a069f0b600ffeb Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Mon, 18 Nov 2024 22:29:13 -0800 Subject: [PATCH] Better encapsulation of HloModuleConfig's fields through setters and returning references instead of pointers. PiperOrigin-RevId: 697878690 --- .../tpu/kernels/tpu_compile_op_support.cc | 2 +- third_party/xla/xla/service/BUILD | 13 ++++++++++ third_party/xla/xla/service/dump_test.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 2 ++ .../xla/service/gpu/gpu_hlo_schedule_test.cc | 2 +- .../gpu/gpu_latency_hiding_scheduler_test.cc | 6 ++++- .../xla/xla/service/gpu/transforms/BUILD | 2 +- .../transforms/pgle_accuracy_checker_test.cc | 8 +++---- .../xla/xla/service/hlo_module_config.cc | 8 +++++-- .../xla/xla/service/hlo_module_config.h | 24 +++++++++++++++---- .../xla/xla/service/hlo_module_config_test.cc | 1 + .../xla/xla/service/hlo_module_util.cc | 13 ++++++++-- .../xla/xla/service/instruction_fusion.cc | 3 +-- 13 files changed, 66 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc index ccf18d82c8dd74..f99342d8ce3569 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -150,7 +150,7 @@ absl::StatusOr> CreateModuleConfig( if (fusion_config_collection != nullptr && fusion_config != nullptr && *fusion_config_collection != xla::FusionConfigCollection::kOff) { config->set_fusion_config_collection(*fusion_config_collection); - *config->mutable_fusion_config() = *fusion_config; + config->set_fusion_config(*fusion_config); } return std::move(config); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index a66d0875ccb813..9f91be8a225fe8 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1768,11 +1768,19 @@ cc_library( hdrs = ["hlo_module_util.h"], deps = [ ":compiler", + ":computation_layout", ":hlo_module_config", + "//xla:debug_options_flags", + "//xla:shape_layout", "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -4312,12 +4320,16 @@ cc_library( ":hlo_proto_cc", "//xla:debug_options_flags", "//xla:shape_layout", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], @@ -4331,6 +4343,7 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/tests:test_utils", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/dump_test.cc b/third_party/xla/xla/service/dump_test.cc index 7d9e4c794a3141..7d8e6d79b3cbbb 100644 --- a/third_party/xla/xla/service/dump_test.cc +++ b/third_party/xla/xla/service/dump_test.cc @@ -154,7 +154,7 @@ TEST(DumpTest, DumpProtobufToFileWhenDisabled) { TEST(DumpTest, DumpFdoProfileToFileWhenEnabled) { std::string fdo_profile = "fdo_profile"; HloModuleConfig config; - *config.mutable_fdo_profile() = fdo_profile; + config.set_fdo_profile(fdo_profile); DebugOptions options = config.debug_options(); auto env = tsl::Env::Default(); std::string dump_dir; diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 027f317d1d2e5c..bb40e8fc87e52e 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3091,9 +3091,11 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index e2c185241ab0dc..7be4373e6fc49e 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -86,7 +86,7 @@ class GpuHloScheduleTest : public HloTestBase { debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker( enable_gpu_async_tracker); config.set_debug_options(debug_options); - *config.mutable_fdo_profile() = fdo_profile; + config.set_fdo_profile(fdo_profile); return config; } diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 42e05cf9db71cd..d0603cdf1e4695 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -15,13 +15,17 @@ limitations under the License. #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" +#include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_hlo_schedule.h" @@ -78,7 +82,7 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase { debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true); debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker(true); config.set_debug_options(debug_options); - *config.mutable_fdo_profile() = fdo_profile; + config.set_fdo_profile(fdo_profile); return config; } }; diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index d518aecdb0cb93..cc081a35692ce3 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -3312,10 +3312,10 @@ xla_cc_test( deps = [ ":pgle_accuracy_checker", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:latency_hiding_scheduler", "//xla/service:profile_guided_latency_estimator", "//xla/service/gpu:gpu_latency_hiding_scheduler", - "//xla/tests:hlo_test_base", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc index d1994a1f6f0687..5e24b8e3aab875 100644 --- a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/service/profile_guided_latency_estimator.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -35,7 +35,7 @@ limitations under the License. namespace xla::gpu { namespace { -using PGLEAccuracyCheckerTest = HloTestBase; +using PGLEAccuracyCheckerTest = HloHardwareIndependentTestBase; using ::tensorflow::profiler::ProfiledInstructionsProto; using ::tsl::protobuf::TextFormat; using ::tsl::testing::StatusIs; @@ -95,7 +95,7 @@ TEST_F(PGLEAccuracyCheckerTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - *module->mutable_config().mutable_fdo_profile() = kProfileString; + module->mutable_config().set_fdo_profile(kProfileString); auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); @@ -147,7 +147,7 @@ TEST_F(PGLEAccuracyCheckerTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - *module->mutable_config().mutable_fdo_profile() = kProfileString; + module->mutable_config().set_fdo_profile(kProfileString); module->mutable_config() .mutable_debug_options() .set_xla_gpu_pgle_accuracy_checker( diff --git a/third_party/xla/xla/service/hlo_module_config.cc b/third_party/xla/xla/service/hlo_module_config.cc index 4b4a1cae218586..85a483238dcef3 100644 --- a/third_party/xla/xla/service/hlo_module_config.cc +++ b/third_party/xla/xla/service/hlo_module_config.cc @@ -24,11 +24,15 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" +#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/xla.pb.h" #include "tsl/platform/statusor.h" @@ -217,7 +221,7 @@ static void AssignStructFusionConfig(HloModuleConfig& config, } module_config.push_back(std::move(temp)); } - *config.mutable_fusion_config() = std::move(module_config); + config.set_fusion_config(std::move(module_config)); } static void AssignStructDotConfig(HloModuleConfig& config, @@ -259,7 +263,7 @@ static void AssignStructPhaseOrderingConfig(HloModuleConfig& config, } module_config.push_back(std::move(temp)); } - *config.mutable_phase_ordering_config() = std::move(module_config); + config.set_phase_ordering_config(std::move(module_config)); } HloModuleConfigProto HloModuleConfig::ToProto() const { diff --git a/third_party/xla/xla/service/hlo_module_config.h b/third_party/xla/xla/service/hlo_module_config.h index 8fce19e8baa547..0644b888528469 100644 --- a/third_party/xla/xla/service/hlo_module_config.h +++ b/third_party/xla/xla/service/hlo_module_config.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_MODULE_CONFIG_H_ #define XLA_SERVICE_HLO_MODULE_CONFIG_H_ +#include #include #include #include @@ -25,11 +26,15 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" @@ -325,8 +330,11 @@ class HloModuleConfig { const std::vector>& fusion_config() const { return fusion_config_; } - std::vector>* mutable_fusion_config() { - return &fusion_config_; + void set_fusion_config(std::vector> fusion_config) { + fusion_config_ = std::move(fusion_config); + } + std::vector>& mutable_fusion_config() { + return fusion_config_; } const absl::flat_hash_map>& dot_config() @@ -347,8 +355,12 @@ class HloModuleConfig { const std::vector>& phase_ordering_config() const { return phase_ordering_config_; } - std::vector>* mutable_phase_ordering_config() { - return &phase_ordering_config_; + void set_phase_ordering_config( + std::vector> phase_ordering_config) { + phase_ordering_config_ = std::move(phase_ordering_config); + } + std::vector>& mutable_phase_ordering_config() { + return phase_ordering_config_; } int phase_index() const { return phase_index_; } @@ -398,7 +410,9 @@ class HloModuleConfig { } absl::string_view fdo_profile() const { return fdo_profile_; } - std::string* mutable_fdo_profile() { return &fdo_profile_; } + void set_fdo_profile(absl::string_view fdo_profile) { + fdo_profile_ = fdo_profile; + } int64_t device_memory_size() const { return device_memory_size_; } void set_device_memory_size(int64_t device_memory_size) { diff --git a/third_party/xla/xla/service/hlo_module_config_test.cc b/third_party/xla/xla/service/hlo_module_config_test.cc index 8bcd6d5ed89391..3f3703ff8fe467 100644 --- a/third_party/xla/xla/service/hlo_module_config_test.cc +++ b/third_party/xla/xla/service/hlo_module_config_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/xla.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_module_util.cc b/third_party/xla/xla/service/hlo_module_util.cc index 1bc65eef147ef9..2d2fb104c5e3e3 100644 --- a/third_party/xla/xla/service/hlo_module_util.cc +++ b/third_party/xla/xla/service/hlo_module_util.cc @@ -15,18 +15,27 @@ limitations under the License. #include "xla/service/hlo_module_util.h" +#include #include #include #include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" +#include "xla/service/computation_layout.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" +#include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -134,7 +143,7 @@ absl::StatusOr> CreateModuleConfig( } config->set_alias_passthrough_params( execution_options->alias_passthrough_params()); - *config->mutable_fdo_profile() = execution_options->fdo_profile(); + config->set_fdo_profile(execution_options->fdo_profile()); config->set_device_memory_size(execution_options->device_memory_size()); config->set_use_shardy_partitioner( execution_options->use_shardy_partitioner()); @@ -154,7 +163,7 @@ absl::StatusOr> CreateModuleConfig( FusionConfigCollection::kOff) { config->set_fusion_config_collection( aot_options->fusion_config_collection()); - *config->mutable_fusion_config() = aot_options->fusion_config(); + config->set_fusion_config(aot_options->fusion_config()); } } diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 3dd0cb630168e0..b68a1614b1bd8f 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -700,8 +700,7 @@ absl::StatusOr InstructionFusion::Run( VLOG(1) << "There are " << fused_count << " fused bits that cause " << fuse_count << " fusion actions."; } - *module->mutable_config().mutable_fusion_config() = - std::move(fusion_config); + module->mutable_config().set_fusion_config(std::move(fusion_config)); } VLOG(1) << "Fusion count: " << fuse_count;