diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 336069e8c9baf5..8d05c8ce160edb 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -500,6 +500,7 @@ cc_library( "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", + "@llvm-project//mlir:VectorToSCF", "@llvm-project//mlir:VectorTransforms", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index 5d96212cfba1d2..c6e7792e1cfe0a 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project @@ -292,6 +293,7 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createConvertVectorToSCFPass()); pm.addNestedPass(xla::cpu::createLegalizeI1VectorTransferOpsPass()); pm.addNestedPass( xla::cpu::createConvertXlaCpuMemRefElementCastToLLVMPass());