diff --git a/deps/LLVMExtra/include/LLVMExtra.h b/deps/LLVMExtra/include/LLVMExtra.h index 453dbfec..5a00faec 100644 --- a/deps/LLVMExtra/include/LLVMExtra.h +++ b/deps/LLVMExtra/include/LLVMExtra.h @@ -254,11 +254,17 @@ void LLVMPassBuilderExtensionsRegisterFunctionPass(LLVMPassBuilderExtensionsRef const char *PassName, LLVMJuliaFunctionPassCallback Callback, void *Thunk); +#if LLVM_VERSION_MAJOR < 20 void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensions, const char *AAPipeline); +#endif LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes, LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options, LLVMPassBuilderExtensionsRef Extensions); +LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes, + LLVMTargetMachineRef TM, + LLVMPassBuilderOptionsRef Options, + LLVMPassBuilderExtensionsRef Extensions); LLVMValueRef LLVMBuildAtomicRMWSyncScope(LLVMBuilderRef B,LLVMAtomicRMWBinOp op, LLVMValueRef PTR, LLVMValueRef Val, diff --git a/deps/LLVMExtra/lib/NewPM.cpp b/deps/LLVMExtra/lib/NewPM.cpp index 0bed1f57..2b765698 100644 --- a/deps/LLVMExtra/lib/NewPM.cpp +++ b/deps/LLVMExtra/lib/NewPM.cpp @@ -145,16 +145,12 @@ void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensi // Vendored API entrypoint -LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes, - LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options, - LLVMPassBuilderExtensionsRef Extensions) { - TargetMachine *Machine = unwrap(TM); - LLVMPassBuilderOptions *PassOpts = unwrap(Options); - LLVMPassBuilderExtensions *PassExts = unwrap(Extensions); +static LLVMErrorRef runJuliaPasses(Module *Mod, Function *Fun, const char *Passes, + TargetMachine *Machine, LLVMPassBuilderOptions *PassOpts, + LLVMPassBuilderExtensions *PassExts) { bool Debug = PassOpts->DebugLogging; bool VerifyEach = PassOpts->VerifyEach; - Module *Mod = unwrap(M); PassInstrumentationCallbacks PIC; #if LLVM_VERSION_MAJOR >= 16 PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC); @@ -203,14 +199,43 @@ LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes, #else SI.registerCallbacks(PIC, &FAM); #endif - ModulePassManager MPM; - if (VerifyEach) { - MPM.addPass(VerifierPass()); - } - if (auto Err = PB.parsePassPipeline(MPM, Passes)) { - return wrap(std::move(Err)); + + if (Fun) { + FunctionPassManager FPM; + if (VerifyEach) + FPM.addPass(VerifierPass()); + if (auto Err = PB.parsePassPipeline(FPM, Passes)) + return wrap(std::move(Err)); + FPM.run(*Fun, FAM); + } else { + ModulePassManager MPM; + if (VerifyEach) + MPM.addPass(VerifierPass()); + if (auto Err = PB.parsePassPipeline(MPM, Passes)) + return wrap(std::move(Err)); + MPM.run(*Mod, MAM); } - MPM.run(*Mod, MAM); return LLVMErrorSuccess; } + +LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes, + LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options, + LLVMPassBuilderExtensionsRef Extensions) { + TargetMachine *Machine = unwrap(TM); + LLVMPassBuilderOptions *PassOpts = unwrap(Options); + LLVMPassBuilderExtensions *PassExts = unwrap(Extensions); + Module *Mod = unwrap(M); + return runJuliaPasses(Mod, nullptr, Passes, Machine, PassOpts, PassExts); +} + +LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes, + LLVMTargetMachineRef TM, + LLVMPassBuilderOptionsRef Options, + LLVMPassBuilderExtensionsRef Extensions) { + TargetMachine *Machine = unwrap(TM); + LLVMPassBuilderOptions *PassOpts = unwrap(Options); + LLVMPassBuilderExtensions *PassExts = unwrap(Extensions); + Function *Fun = unwrap(F); + return runJuliaPasses(Fun->getParent(), Fun, Passes, Machine, PassOpts, PassExts); +} diff --git a/lib/15/libLLVM_extra.jl b/lib/15/libLLVM_extra.jl index 45eeb9fb..180545fc 100644 --- a/lib/15/libLLVM_extra.jl +++ b/lib/15/libLLVM_extra.jl @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions) ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions) end +function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions) + ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions) +end + function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope) ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope) end diff --git a/lib/16/libLLVM_extra.jl b/lib/16/libLLVM_extra.jl index 45eeb9fb..180545fc 100644 --- a/lib/16/libLLVM_extra.jl +++ b/lib/16/libLLVM_extra.jl @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions) ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions) end +function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions) + ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions) +end + function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope) ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope) end diff --git a/lib/17/libLLVM_extra.jl b/lib/17/libLLVM_extra.jl index 69290063..74bcb0f0 100644 --- a/lib/17/libLLVM_extra.jl +++ b/lib/17/libLLVM_extra.jl @@ -345,6 +345,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions) ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions) end +function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions) + ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions) +end + function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope) ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope) end diff --git a/lib/18/libLLVM_extra.jl b/lib/18/libLLVM_extra.jl index d0ba92a4..04629305 100644 --- a/lib/18/libLLVM_extra.jl +++ b/lib/18/libLLVM_extra.jl @@ -279,6 +279,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions) ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions) end +function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions) + ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions) +end + function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope) ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope) end diff --git a/src/newpm.jl b/src/newpm.jl index de3279bb..e63b399b 100644 --- a/src/newpm.jl +++ b/src/newpm.jl @@ -247,7 +247,7 @@ represents a pass pipeline. The target machine is used to optimize the passes. """ run! -function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine}=nothing) +function run!(pb::NewPMPassBuilder, target::Union{Module,Function}, tm::Union{Nothing,TargetMachine}=nothing) isempty(pb.passes) && return pipeline = join(pb.passes, ",") aa_pipeline = join(pb.aa_passes, ",") @@ -286,8 +286,13 @@ function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine end end - @check API.LLVMRunJuliaPasses(mod, pipeline, something(tm, C_NULL), - pb.opts, pb.exts) + if target isa Module + @check API.LLVMRunJuliaPasses(target, pipeline, something(tm, C_NULL), + pb.opts, pb.exts) + elseif target isa Function + @check API.LLVMRunJuliaPassesOnFunction(target, pipeline, something(tm, C_NULL), + pb.opts, pb.exts) + end end end diff --git a/test/newpm_tests.jl b/test/newpm_tests.jl index 42fc4e4e..05d5b06d 100644 --- a/test/newpm_tests.jl +++ b/test/newpm_tests.jl @@ -23,14 +23,19 @@ end @dispose ctx=Context() begin # single pass @dispose mod=test_module() begin + fun = only(functions(mod)) + # by string @test run!("no-op-module", mod) === nothing + @test run!("no-op-function", fun) === nothing # by object @test run!(NoOpModulePass(), mod) === nothing + @test run!(NoOpFunctionPass(), fun) === nothing # by object with options @test run!(LoopExtractorPass(; single=true), mod) === nothing + @test run!(EarlyCSEPass(; memssa=true), fun) === nothing end # default pipelines