Skip to content

Commit

Permalink
Add a vector->SCF pass to hlo_xla_runtime_pipeline.
Browse files Browse the repository at this point in the history
Without this pass some of the vector.transfers are not unrolled/converted and
the pipeline can fail.

PiperOrigin-RevId: 585893540
  • Loading branch information
pifon2a authored and tensorflower-gardener committed Nov 28, 2023
1 parent b37d189 commit d4fc2ec
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ cc_library(
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@llvm-project//mlir:VectorTransforms",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project
Expand Down Expand Up @@ -292,6 +293,7 @@ static Status CreateHloXlaPipeline(

pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<FuncOp>(mlir::createConvertVectorToSCFPass());
pm.addNestedPass<FuncOp>(xla::cpu::createLegalizeI1VectorTransferOpsPass());
pm.addNestedPass<FuncOp>(
xla::cpu::createConvertXlaCpuMemRefElementCastToLLVMPass());
Expand Down

0 comments on commit d4fc2ec

Please sign in to comment.