Skip to content

Commit

Permalink
[Truncate] Handle libm calls (#1636)
Browse files Browse the repository at this point in the history
* Handle intrinsics

* Add integration test

* Add common handler for math intrinsics

* Format EnzymeLogic.cpp
  • Loading branch information
ivanradanov authored Feb 8, 2024
1 parent 21053d4 commit bd36bae
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 23 deletions.
54 changes: 31 additions & 23 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
//===----------------------------------------------------------------------===//
#include "ActivityAnalysis.h"
#include "AdjointGenerator.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Intrinsics.h"

#if LLVM_VERSION_MAJOR >= 16
Expand Down Expand Up @@ -5075,46 +5076,47 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
return;
}
void visitFenceInst(llvm::FenceInst &FI) { return; }
void visitIntrinsicInst(llvm::IntrinsicInst &II) {
SmallVector<Value *, 2> orig_ops(II.arg_size());
for (unsigned i = 0; i < II.arg_size(); ++i)
orig_ops[i] = II.getOperand(i);
if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops))
return;

bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) {
SmallVector<Value *, 2> orig_ops(CI.arg_size());
for (unsigned i = 0; i < CI.arg_size(); ++i)
orig_ops[i] = CI.getOperand(i);

bool hasFromType = false;
auto newI = cast<llvm::IntrinsicInst>(getNewFromOriginal(&II));
auto newI = cast<llvm::CallInst>(getNewFromOriginal(&CI));
IRBuilder<> B(newI);
SmallVector<Value *, 2> new_ops(II.arg_size());
for (unsigned i = 0; i < II.arg_size(); ++i) {
SmallVector<Value *, 2> new_ops(CI.arg_size());
for (unsigned i = 0; i < CI.arg_size(); ++i) {
if (orig_ops[i]->getType() == getFromType()) {
new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i]));
hasFromType = true;
} else {
new_ops[i] = getNewFromOriginal(orig_ops[i]);
}
}
Type *retTy = II.getType();
if (II.getType() == getFromType()) {
Type *retTy = CI.getType();
if (CI.getType() == getFromType()) {
hasFromType = true;
retTy = getToType();
}

if (!hasFromType)
return;
return false;

// TODO check that the intrinsic is overloaded

CallInst *intr;
Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy,
new_ops, &II, II.getName());
if (II.getType() == getFromType())
Value *nres = intr =
createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName());
if (CI.getType() == getFromType())
nres = expand(B, nres);
intr->copyIRFlags(newI);
newI->replaceAllUsesWith(nres);
newI->eraseFromParent();

return;
return true;
}
void visitIntrinsicInst(llvm::IntrinsicInst &II) {
handleIntrinsic(II, II.getIntrinsicID());
}

void visitReturnInst(llvm::ReturnInst &I) { return; }
Expand Down Expand Up @@ -5201,18 +5203,24 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
return v;
}
// Return
void visitCallInst(llvm::CallInst &call) {
void visitCallInst(llvm::CallInst &CI) {
Intrinsic::ID ID;
StringRef funcName = getFuncNameFromCall(const_cast<CallInst *>(&CI));
if (isMemFreeLibMFunction(funcName, &ID))
if (handleIntrinsic(CI, ID))
return;

using namespace llvm;

CallInst *const newCall = cast<CallInst>(getNewFromOriginal(&call));
CallInst *const newCall = cast<CallInst>(getNewFromOriginal(&CI));
IRBuilder<> BuilderZ(newCall);

if (auto called = call.getCalledFunction())
if (handleKnownCalls(call, called, getFuncNameFromCall(&call), newCall))
if (auto called = CI.getCalledFunction())
if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall))
return;

RequestContext ctx(&call, &BuilderZ);
auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand()));
RequestContext ctx(&CI, &BuilderZ);
auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand()));
newCall->setCalledOperand(val);
return;
}
Expand Down
1 change: 1 addition & 0 deletions enzyme/test/Integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_subdirectory(ForwardModeVector)
add_subdirectory(ReverseMode)
add_subdirectory(BatchMode)
add_subdirectory(Sparse)
add_subdirectory(Truncate)

# Run regression and unit tests
add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests"
Expand Down
9 changes: 9 additions & 0 deletions enzyme/test/Integration/Truncate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Register the lit suite for the Truncate integration tests in this directory.
# The description string is what the build prints when the target runs, so it
# names the truncate suite rather than the batch-mode suite it was copied from.
add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme truncate integration tests"
    ${CMAKE_CURRENT_BINARY_DIR}
    DEPENDS ${ENZYME_TEST_DEPS}
    ARGS -v
)

# Group the target under "Tests" in IDE project views.
set_target_properties(check-enzyme-integration-truncate PROPERTIES FOLDER "Tests")

111 changes: 111 additions & 0 deletions enzyme/test/Integration/Truncate/simple.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// COM: %clang -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli -
// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli -
// COM: %clang -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli -
// COM: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli -

#include <math.h>

#include "../test_utils.h"

#define N 10

// Plain floating-point addition; the simplest kernel handed to the
// truncation entry point in main().
double simple_add(double a, double b) {
  const double sum = a + b;
  return sum;
}
// Kernel built from libm calls: sqrt(a) multiplied by pow(b, 2).
// The sqrt/pow calls are deliberate (they are the math calls this test is
// meant to exercise), so they must not be replaced by hand-written arithmetic.
double intrinsics(double a, double b) {
  const double root = sqrt(a);
  const double squared = pow(b, 2);
  return root * squared;
}
// TODO
// Kernel that ignores both arguments and always yields the constant 2.
double constt(double a, double b) {
  const double result = 2;
  return result;
}
// Writes A[i] * 2 into C[i] for every i in [0, n) and returns C[0].
// B is accepted but never read. With n <= 0 nothing is written and the
// value already stored in C[0] is returned.
double compute(double *A, double *B, double *C, int n) {
  int idx = 0;
  while (idx < n) {
    C[idx] = A[idx] * 2;
    ++idx;
  }
  return *C;
}

// Pointer to a kernel shaped like `compute`: three double buffers and a count.
using fty = double (*)(double *, double *, double *, int);

// Pointer to a binary scalar kernel such as `simple_add` or `intrinsics`.
using fty2 = double (*)(double, double);

// Variadic hooks that are declared but not defined in this file.
// NOTE(review): presumably resolved by name by the Enzyme pass loaded in the
// RUN line, so the symbol names must stay exactly as written.
// main() passes (function, FROM, TO) to the *_func hooks and
// (value, FROM, TO) to the *_value hooks.
extern fty __enzyme_truncate_func_2(...);
extern fty2 __enzyme_truncate_func(...);
extern double __enzyme_truncate_value(...);
extern double __enzyme_expand_value(...);

// Bit widths passed to the __enzyme_truncate_* / __enzyme_expand_* hooks in
// main(): FROM is given first and TO second at every call site.
// Kept as macros so the call sites receive plain integer literals.
#define FROM 64
#define TO 32

// The unfinished `TEST(F)` macro was dropped: it expanded to an unterminated
// `do {`, nothing in this file used it, and any use would not have compiled.


// Test driver. Each braced scope below is an independent check built on the
// __enzyme_* hooks declared above and on APPROX_EQ from ../test_utils.h.
// NOTE(review): APPROX_EQ presumably checks that its first two arguments agree
// within the tolerance given as the third; confirm in test_utils.h.
int main() {

// Round trip: truncate a value, expand it again, and compare with the
// original using a tolerance of 1e-10.
{
double a = 1;
APPROX_EQ(
__enzyme_expand_value(
__enzyme_truncate_value(a, FROM, TO) , FROM, TO),
a, 1e-10);
}

// simple_add: the reference result is computed on the untouched inputs. The
// inputs are then passed through __enzyme_truncate_value, fed to the function
// returned by __enzyme_truncate_func, and the result is expanded before being
// compared with tolerance 1e-5.
{
double a = 2;
double b = 3;
double truth = simple_add(a, b);
a = __enzyme_truncate_value(a, FROM, TO);
b = __enzyme_truncate_value(b, FROM, TO);
double trunc = __enzyme_expand_value(__enzyme_truncate_func(simple_add, FROM, TO)(a, b), FROM, TO);
APPROX_EQ(trunc, truth, 1e-5);
}
// intrinsics: same procedure as above, applied to the sqrt/pow kernel.
{
double a = 2;
double b = 3;
double truth = intrinsics(a, b);
a = __enzyme_truncate_value(a, FROM, TO);
b = __enzyme_truncate_value(b, FROM, TO);
double trunc = __enzyme_expand_value(__enzyme_truncate_func(intrinsics, FROM, TO)(a, b), FROM, TO);
APPROX_EQ(trunc, truth, 1e-5);
}
// Disabled check for constt.
// NOTE(review): as written it takes `truth` from intrinsics() while truncating
// constt; it looks like `truth` should come from constt() before re-enabling.
// {
// double a = 2;
// double b = 3;
// double truth = intrinsics(a, b);
// a = __enzyme_truncate_value(a, FROM, TO);
// b = __enzyme_truncate_value(b, FROM, TO);
// double trunc = __enzyme_expand_value(__enzyme_truncate_func(constt, FROM, TO)(a, b), FROM, TO);
// APPROX_EQ(trunc, truth, 1e-5);
// }

// Disabled array check for compute(): it fills A and B, computes a reference
// into D, truncates A and B element-wise, runs the truncated compute into C,
// expands C, and compares C with D element by element.
// double A[N];
// double B[N];
// double C[N];
// double D[N];


// for (int i = 0; i < N; i++) {
// A[i] = 1 + i % 5;
// B[i] = 1 + i % 3;
// }

// compute(A, B, D, N);

// for (int i = 0; i < N; i++) {
// A[i] = __enzyme_truncate_value(A[i], 64, 32);
// B[i] = __enzyme_truncate_value(B[i], 64, 32);
// }

// __enzyme_truncate_func_2(compute, 64, 32)(A, B, C, N);

// for (int i = 0; i < N; i++) {
// C[i] = __enzyme_expand_value(C[i], 64, 32);
// }

// for (int i = 0; i < N; i++) {
// printf("%d\n", i);
// APPROX_EQ(D[i], C[i], 1e-5);
// }

}

0 comments on commit bd36bae

Please sign in to comment.