Skip to content

Commit

Permalink
Add the ability to run passes on a single function. (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Aug 19, 2024
1 parent 818b082 commit 22a1d59
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 17 deletions.
6 changes: 6 additions & 0 deletions deps/LLVMExtra/include/LLVMExtra.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,17 @@ void LLVMPassBuilderExtensionsRegisterFunctionPass(LLVMPassBuilderExtensionsRef
const char *PassName,
LLVMJuliaFunctionPassCallback Callback,
void *Thunk);
#if LLVM_VERSION_MAJOR < 20
void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensions,
const char *AAPipeline);
#endif
LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions);
LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes,
LLVMTargetMachineRef TM,
LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions);

LLVMValueRef LLVMBuildAtomicRMWSyncScope(LLVMBuilderRef B,LLVMAtomicRMWBinOp op,
LLVMValueRef PTR, LLVMValueRef Val,
Expand Down
53 changes: 39 additions & 14 deletions deps/LLVMExtra/lib/NewPM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,12 @@ void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensi

// Vendored API entrypoint

LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
static LLVMErrorRef runJuliaPasses(Module *Mod, Function *Fun, const char *Passes,
TargetMachine *Machine, LLVMPassBuilderOptions *PassOpts,
LLVMPassBuilderExtensions *PassExts) {
bool Debug = PassOpts->DebugLogging;
bool VerifyEach = PassOpts->VerifyEach;

Module *Mod = unwrap(M);
PassInstrumentationCallbacks PIC;
#if LLVM_VERSION_MAJOR >= 16
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC);
Expand Down Expand Up @@ -203,14 +199,43 @@ LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
#else
SI.registerCallbacks(PIC, &FAM);
#endif
ModulePassManager MPM;
if (VerifyEach) {
MPM.addPass(VerifierPass());
}
if (auto Err = PB.parsePassPipeline(MPM, Passes)) {
return wrap(std::move(Err));

if (Fun) {
FunctionPassManager FPM;
if (VerifyEach)
FPM.addPass(VerifierPass());
if (auto Err = PB.parsePassPipeline(FPM, Passes))
return wrap(std::move(Err));
FPM.run(*Fun, FAM);
} else {
ModulePassManager MPM;
if (VerifyEach)
MPM.addPass(VerifierPass());
if (auto Err = PB.parsePassPipeline(MPM, Passes))
return wrap(std::move(Err));
MPM.run(*Mod, MAM);
}

MPM.run(*Mod, MAM);
return LLVMErrorSuccess;
}

LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
Module *Mod = unwrap(M);
return runJuliaPasses(Mod, nullptr, Passes, Machine, PassOpts, PassExts);
}

LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes,
LLVMTargetMachineRef TM,
LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
Function *Fun = unwrap<Function>(F);
return runJuliaPasses(Fun->getParent(), Fun, Passes, Machine, PassOpts, PassExts);
}
4 changes: 4 additions & 0 deletions lib/15/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/16/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/17/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/18/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
11 changes: 8 additions & 3 deletions src/newpm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ represents a pass pipeline. The target machine is used to optimize the passes.
"""
run!

function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine}=nothing)
function run!(pb::NewPMPassBuilder, target::Union{Module,Function}, tm::Union{Nothing,TargetMachine}=nothing)
isempty(pb.passes) && return
pipeline = join(pb.passes, ",")
aa_pipeline = join(pb.aa_passes, ",")
Expand Down Expand Up @@ -286,8 +286,13 @@ function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine
end
end

@check API.LLVMRunJuliaPasses(mod, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
if target isa Module
@check API.LLVMRunJuliaPasses(target, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
elseif target isa Function
@check API.LLVMRunJuliaPassesOnFunction(target, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
end
end
end

Expand Down
5 changes: 5 additions & 0 deletions test/newpm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ end
@dispose ctx=Context() begin
# single pass
@dispose mod=test_module() begin
fun = only(functions(mod))

# by string
@test run!("no-op-module", mod) === nothing
@test run!("no-op-function", fun) === nothing

# by object
@test run!(NoOpModulePass(), mod) === nothing
@test run!(NoOpFunctionPass(), fun) === nothing

# by object with options
@test run!(LoopExtractorPass(; single=true), mod) === nothing
@test run!(EarlyCSEPass(; memssa=true), fun) === nothing
end

# default pipelines
Expand Down

0 comments on commit 22a1d59

Please sign in to comment.