
Commit

[XLA:GPU] Add coalescing heuristic.
PiperOrigin-RevId: 606537304
olegshyshkov authored and tensorflower-gardener committed Feb 13, 2024
1 parent 9a39952 commit f43f8c3
Showing 3 changed files with 12 additions and 7 deletions.
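
The change threads a per-fusion coalescing decision, computed by IsReadCoalescedHeuristic, into EstimateRunTimeForFusion, replacing the /*coalesced=*/true value that was previously hard-coded at the ReadTimeWithDRAMHeuristic call site. As a purely illustrative sketch of what such a heuristic asks (not the XLA implementation; the function name, warp size, and transaction size below are assumptions), a read counts as coalesced when the addresses touched by one warp fit into the minimum possible number of memory transactions:

#include <cstdint>

// Illustrative only: treats a strided warp read as coalesced if it needs no
// more memory transactions than a fully contiguous read would. Ignores
// alignment for simplicity.
bool WarpReadIsCoalesced(int64_t stride_bytes, int64_t element_bytes) {
  constexpr int64_t kWarpSize = 32;           // threads issuing the read together
  constexpr int64_t kTransactionBytes = 128;  // assumed DRAM transaction granularity
  int64_t bytes_spanned = (kWarpSize - 1) * stride_bytes + element_bytes;
  int64_t transactions =
      (bytes_spanned + kTransactionBytes - 1) / kTransactionBytes;
  int64_t minimal_transactions =
      (kWarpSize * element_bytes + kTransactionBytes - 1) / kTransactionBytes;
  return transactions <= minimal_transactions;
}

Whatever the exact criterion, the performance model only consumes the resulting boolean, which is what this commit plumbs through.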
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
@@ -324,6 +324,7 @@ cc_library(
     hdrs = ["gpu_indexing_performance_model.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":coalescing_analysis",
         ":gpu_hlo_cost_analysis",
         ":gpu_performance_model_base",
         ":hlo_op_profiles",
16 changes: 10 additions & 6 deletions third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/coalescing_analysis.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
 #include "xla/service/gpu/model/gpu_performance_model_base.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
@@ -79,7 +80,7 @@ int64_t GetIterationSpaceSize(const IndexingMap& indexing_map,

 EstimateRunTimeData
 GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion(
-    const HloFusionAnalysis& fusion_analysis) {
+    const HloFusionAnalysis& fusion_analysis, bool is_coalesced) {
   auto& fusion_adaptor = fusion_analysis.fusion();
   auto roots = fusion_adaptor.GetRoots();
   CHECK_EQ(roots.size(), 1)
@@ -126,9 +127,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion(
       int64_t n_bytes_net = shape_size_(instr->shape());
       auto element_type = instr->shape().element_type();

-      read_time += ReadTimeWithDRAMHeuristic(*device_info_, num_blocks,
-                                             n_bytes_net, n_bytes_total,
-                                             element_type, /*coalesced=*/true);
+      read_time +=
+          ReadTimeWithDRAMHeuristic(*device_info_, num_blocks, n_bytes_net,
+                                    n_bytes_total, element_type, is_coalesced);
     }
   }

@@ -155,7 +156,8 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction(

   auto fusion_analysis = AnalyzeFusion(*producer, *device_info_);

-  return EstimateRunTimeForFusion(fusion_analysis);
+  bool is_coalesced = IsReadCoalescedHeuristic(fusion_analysis, producer);
+  return EstimateRunTimeForFusion(fusion_analysis, is_coalesced);
 }

 EstimateRunTimeData
@@ -164,7 +166,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer(
   auto fusion_analysis =
       AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_);

-  return EstimateRunTimeForFusion(fusion_analysis);
+  bool is_coalesced =
+      IsReadCoalescedHeuristic(fusion_analysis, producer, consumer);
+  return EstimateRunTimeForFusion(fusion_analysis, is_coalesced);
 }

 /*static*/
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h
@@ -44,7 +44,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
         mlir_context_(mlir_context) {}

   EstimateRunTimeData EstimateRunTimeForFusion(
-      const HloFusionAnalysis& fusion_analysis);
+      const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true);

   EstimateRunTimeData EstimateRunTimeForInstruction(
       const HloInstruction* producer);
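
Since the new parameter defaults to true in the header, callers that omit it keep the previous behavior of modeling reads as coalesced; only the two call sites updated above opt into the heuristic. A hypothetical caller, sketched from nothing but the signatures visible in this diff (the names model, producer, and device_info are stand-ins for whatever the real caller has in scope), would compute the flag once and forward it:

  // Hypothetical fragment, not actual XLA code.
  auto fusion_analysis = AnalyzeFusion(*producer, device_info);
  bool is_coalesced = IsReadCoalescedHeuristic(fusion_analysis, producer);
  EstimateRunTimeData runtime_data =
      model.EstimateRunTimeForFusion(fusion_analysis, is_coalesced);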
