Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to run passes on a single function. #441

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deps/LLVMExtra/include/LLVMExtra.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,17 @@ void LLVMPassBuilderExtensionsRegisterFunctionPass(LLVMPassBuilderExtensionsRef
const char *PassName,
LLVMJuliaFunctionPassCallback Callback,
void *Thunk);
#if LLVM_VERSION_MAJOR < 20
void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensions,
const char *AAPipeline);
#endif
LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions);
LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes,
LLVMTargetMachineRef TM,
LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions);

LLVMValueRef LLVMBuildAtomicRMWSyncScope(LLVMBuilderRef B,LLVMAtomicRMWBinOp op,
LLVMValueRef PTR, LLVMValueRef Val,
Expand Down
53 changes: 39 additions & 14 deletions deps/LLVMExtra/lib/NewPM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,12 @@ void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensi

// Vendored API entrypoint

LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
static LLVMErrorRef runJuliaPasses(Module *Mod, Function *Fun, const char *Passes,
TargetMachine *Machine, LLVMPassBuilderOptions *PassOpts,
LLVMPassBuilderExtensions *PassExts) {
bool Debug = PassOpts->DebugLogging;
bool VerifyEach = PassOpts->VerifyEach;

Module *Mod = unwrap(M);
PassInstrumentationCallbacks PIC;
#if LLVM_VERSION_MAJOR >= 16
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC);
Expand Down Expand Up @@ -203,14 +199,43 @@ LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
#else
SI.registerCallbacks(PIC, &FAM);
#endif
ModulePassManager MPM;
if (VerifyEach) {
MPM.addPass(VerifierPass());
}
if (auto Err = PB.parsePassPipeline(MPM, Passes)) {
return wrap(std::move(Err));

if (Fun) {
FunctionPassManager FPM;
if (VerifyEach)
FPM.addPass(VerifierPass());
if (auto Err = PB.parsePassPipeline(FPM, Passes))
return wrap(std::move(Err));
FPM.run(*Fun, FAM);
} else {
ModulePassManager MPM;
if (VerifyEach)
MPM.addPass(VerifierPass());
if (auto Err = PB.parsePassPipeline(MPM, Passes))
return wrap(std::move(Err));
MPM.run(*Mod, MAM);
}

MPM.run(*Mod, MAM);
return LLVMErrorSuccess;
}

LLVMErrorRef LLVMRunJuliaPasses(LLVMModuleRef M, const char *Passes,
LLVMTargetMachineRef TM, LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
Module *Mod = unwrap(M);
return runJuliaPasses(Mod, nullptr, Passes, Machine, PassOpts, PassExts);
}

LLVMErrorRef LLVMRunJuliaPassesOnFunction(LLVMValueRef F, const char *Passes,
LLVMTargetMachineRef TM,
LLVMPassBuilderOptionsRef Options,
LLVMPassBuilderExtensionsRef Extensions) {
TargetMachine *Machine = unwrap(TM);
LLVMPassBuilderOptions *PassOpts = unwrap(Options);
LLVMPassBuilderExtensions *PassExts = unwrap(Extensions);
Function *Fun = unwrap<Function>(F);
return runJuliaPasses(Fun->getParent(), Fun, Passes, Machine, PassOpts, PassExts);
}
4 changes: 4 additions & 0 deletions lib/15/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/16/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/17/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/18/libLLVM_extra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ function LLVMRunJuliaPasses(M, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPasses, libLLVMExtra), LLVMErrorRef, (LLVMModuleRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), M, Passes, TM, Options, Extensions)
end

function LLVMRunJuliaPassesOnFunction(F, Passes, TM, Options, Extensions)
ccall((:LLVMRunJuliaPassesOnFunction, libLLVMExtra), LLVMErrorRef, (LLVMValueRef, Cstring, LLVMTargetMachineRef, LLVMPassBuilderOptionsRef, LLVMPassBuilderExtensionsRef), F, Passes, TM, Options, Extensions)
end

function LLVMBuildAtomicRMWSyncScope(B, op, PTR, Val, ordering, syncscope)
ccall((:LLVMBuildAtomicRMWSyncScope, libLLVMExtra), LLVMValueRef, (LLVMBuilderRef, LLVMAtomicRMWBinOp, LLVMValueRef, LLVMValueRef, LLVMAtomicOrdering, Cstring), B, op, PTR, Val, ordering, syncscope)
end
Expand Down
11 changes: 8 additions & 3 deletions src/newpm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ represents a pass pipeline. The target machine is used to optimize the passes.
"""
run!

function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine}=nothing)
function run!(pb::NewPMPassBuilder, target::Union{Module,Function}, tm::Union{Nothing,TargetMachine}=nothing)
isempty(pb.passes) && return
pipeline = join(pb.passes, ",")
aa_pipeline = join(pb.aa_passes, ",")
Expand Down Expand Up @@ -286,8 +286,13 @@ function run!(pb::NewPMPassBuilder, mod::Module, tm::Union{Nothing,TargetMachine
end
end

@check API.LLVMRunJuliaPasses(mod, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
if target isa Module
@check API.LLVMRunJuliaPasses(target, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
elseif target isa Function
@check API.LLVMRunJuliaPassesOnFunction(target, pipeline, something(tm, C_NULL),
pb.opts, pb.exts)
end
end
end

Expand Down
5 changes: 5 additions & 0 deletions test/newpm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ end
@dispose ctx=Context() begin
# single pass
@dispose mod=test_module() begin
fun = only(functions(mod))

# by string
@test run!("no-op-module", mod) === nothing
@test run!("no-op-function", fun) === nothing

# by object
@test run!(NoOpModulePass(), mod) === nothing
@test run!(NoOpFunctionPass(), fun) === nothing

# by object with options
@test run!(LoopExtractorPass(; single=true), mod) === nothing
@test run!(EarlyCSEPass(; memssa=true), fun) === nothing
end

# default pipelines
Expand Down
Loading