diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 03213fc081b3..ddac330fe7c2 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1084,8 +1084,14 @@ class AdjointGenerator auto dt = vd[{-1}]; for (size_t i = start; i < size; ++i) { + auto nex = vd[{(int)i}]; + if ((nex == BaseType::Anything && dt.isFloat()) || + (dt == BaseType::Anything && nex.isFloat())) { + nextStart = i; + break; + } bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); + dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); if (!Legal) { nextStart = i; break; @@ -1199,7 +1205,8 @@ class AdjointGenerator Builder2, align, start, size, isVolatile, ordering, syncScope, mask, prevNoAlias, prevScopes); ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_val, diff, Builder2, FT, start, size, mask); + ->addToDiffe(orig_val, diff, Builder2, FT, start, size, {}, + mask); } break; } @@ -1909,56 +1916,80 @@ class AdjointGenerator if (!gutils->isConstantValue(orig_inserted)) { auto TT = TR.query(orig_inserted); - auto it = TT[{-1}]; - bool Legal = true; - for (size_t i = 0; i < size0; ++i) { - bool LegalOr = true; - it.checkedOrIn(TT[{(int)i}], /*pointerIntSame*/ true, LegalOr); - Legal &= LegalOr; - } - Type *flt = it.isFloat(); - if (!it.isKnown() || !Legal) { - bool found = false; - - if (looseTypeAnalysis && !Legal) { - if (orig_inserted->getType()->isFPOrFPVectorTy()) { - flt = orig_inserted->getType()->getScalarType(); - found = true; - } else if (orig_inserted->getType()->isIntOrIntVectorTy() || - orig_inserted->getType()->isPointerTy()) { - flt = nullptr; - found = true; + + unsigned start = 0; + Value *dindex = nullptr; + + while (1) { + unsigned nextStart = size0; + + auto dt = TT[{-1}]; + for (size_t i = start; i < size0; ++i) { + auto nex = TT[{(int)i}]; + if ((nex == BaseType::Anything && dt.isFloat()) || + (dt == BaseType::Anything && nex.isFloat())) { + nextStart = i; + break; + } + bool Legal = true; + dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); + if (!Legal) { + nextStart = i; + break; } } - if (!found) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of insertvalue " << IVI - << " size: " << size0 << " TT: " << TT.str(); - if (CustomErrorHandler) { - CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType, - &TR.analyzer, nullptr, wrap(&Builder2)); - } else { - EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI, - ss.str()); + Type *flt = dt.isFloat(); + if (!dt.isKnown()) { + bool found = false; + if (looseTypeAnalysis) { + if (orig_inserted->getType()->isFPOrFPVectorTy()) { + flt = orig_inserted->getType()->getScalarType(); + found = true; + } else if (orig_inserted->getType()->isIntOrIntVectorTy() || + orig_inserted->getType()->isPointerTy()) { + flt = nullptr; + found = true; + } + } + if (!found) { + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of insertvalue ins " << IVI + << " size: " << size0 << " TT: " << TT.str(); + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType, + &TR.analyzer, nullptr, wrap(&Builder2)); + } else { + EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI, + ss.str()); + } } } - } - if (flt) { - auto rule = [&](Value *prediff) { - return Builder2.CreateExtractValue(prediff, IVI.getIndices()); - }; - auto prediff = diffe(&IVI, Builder2); - auto dindex = - applyChainRule(orig_inserted->getType(), Builder2, rule, prediff); - addToDiffe(orig_inserted, dindex, Builder2, flt); + + if (flt) { + if (!dindex) { + auto rule = [&](Value *prediff) { + return Builder2.CreateExtractValue(prediff, IVI.getIndices()); + }; + auto prediff = diffe(&IVI, Builder2); + dindex = applyChainRule(orig_inserted->getType(), Builder2, rule, + prediff); + } + + auto TT = TR.query(orig_inserted); + + ((DiffeGradientUtils *)gutils) + ->addToDiffe(orig_inserted, dindex, Builder2, flt, start, + nextStart - start); + } + if (nextStart == size0) + break; + start = nextStart; } } size_t size1 = 1; - if (orig_agg->getType()->isSized() && - (orig_agg->getType()->isIntOrIntVectorTy() || - orig_agg->getType()->isFPOrFPVectorTy())) + if (orig_agg->getType()->isSized()) size1 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( orig_agg->getType()) + @@ -1966,15 +1997,79 @@ class AdjointGenerator 8; if (!gutils->isConstantValue(orig_agg)) { - auto rule = [&](Value *prediff) { - return Builder2.CreateInsertValue( - prediff, Constant::getNullValue(orig_inserted->getType()), - IVI.getIndices()); - }; - auto prediff = diffe(&IVI, Builder2); - auto dindex = - applyChainRule(orig_agg->getType(), Builder2, rule, prediff); - addToDiffe(orig_agg, dindex, Builder2, TR.addingType(size1, orig_agg)); + + auto TT = TR.query(orig_agg); + + unsigned start = 0; + + Value *dindex = nullptr; + + while (1) { + unsigned nextStart = size1; + + auto dt = TT[{-1}]; + for (size_t i = start; i < size1; ++i) { + auto nex = TT[{(int)i}]; + if ((nex == BaseType::Anything && dt.isFloat()) || + (dt == BaseType::Anything && nex.isFloat())) { + nextStart = i; + break; + } + bool Legal = true; + dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal); + if (!Legal) { + nextStart = i; + break; + } + } + Type *flt = dt.isFloat(); + if (!dt.isKnown()) { + bool found = false; + if (looseTypeAnalysis) { + if (orig_agg->getType()->isFPOrFPVectorTy()) { + flt = orig_agg->getType()->getScalarType(); + found = true; + } else if (orig_agg->getType()->isIntOrIntVectorTy() || + orig_agg->getType()->isPointerTy()) { + flt = nullptr; + found = true; + } + } + if (!found) { + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of insertvalue agg " << IVI + << " start: " << start << " size: " << size1 + << " TT: " << TT.str(); + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType, + &TR.analyzer, nullptr, wrap(&Builder2)); + } else { + EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI, + ss.str()); + } + } + } + + if (flt) { + if (!dindex) { + auto rule = [&](Value *prediff) { + return Builder2.CreateInsertValue( + prediff, Constant::getNullValue(orig_inserted->getType()), + IVI.getIndices()); + }; + auto prediff = diffe(&IVI, Builder2); + dindex = + applyChainRule(orig_agg->getType(), Builder2, rule, prediff); + } + ((DiffeGradientUtils *)gutils) + ->addToDiffe(orig_agg, dindex, Builder2, flt, start, + nextStart - start); + } + if (nextStart == size1) + break; + start = nextStart; + } } setDiffe(&IVI, diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 387548c225fb..68a7e02d048d 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -47,6 +47,7 @@ #include "llvm/Support/ErrorHandling.h" #include "LibraryFuncs.h" +#include "Utils.h" using namespace llvm; @@ -220,45 +221,120 @@ Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) { SmallVector DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, Type *addingType, unsigned start, unsigned size, + llvm::ArrayRef idxs, llvm::Value *mask) { assert(addingType); auto &DL = oldFunc->getParent()->getDataLayout(); - auto storeSize = (DL.getTypeSizeInBits(val->getType()) + 7) / 8; - if (start == 0 && size == storeSize) { - return addToDiffe(val, dif, BuilderM, addingType, ArrayRef(), - mask); + Type *VT = val->getType(); + for (auto cv : idxs) { + auto i = dyn_cast(cv)->getSExtValue(); + if (auto ST = dyn_cast(VT)) { + VT = ST->getElementType(i); + continue; + } + if (auto AT = dyn_cast(VT)) { + assert(i < AT->getNumElements()); + VT = AT->getElementType(); + continue; + } + assert(0 && "illegal indexing type"); + } + auto storeSize = (DL.getTypeSizeInBits(VT) + 7) / 8; + + assert(start < storeSize); + assert(start + size <= storeSize); + + // If VT is a struct type the addToDiffe algorithm will lose type information + // so we do the recurrence here, with full type information. + if (start == 0 && size == storeSize && !isa(VT)) { + if (getWidth() == 1) { + SmallVector eidxs; + for (auto idx : idxs) { + eidxs.push_back((unsigned)cast(idx)->getZExtValue()); + } + return addToDiffe(val, extractMeta(BuilderM, dif, eidxs), BuilderM, + addingType, idxs, mask); + } else { + SmallVector res; + for (unsigned j = 0; j < getWidth(); j++) { + SmallVector lidxs; + SmallVector eidxs = {(unsigned)j}; + lidxs.push_back( + ConstantInt::get(Type::getInt32Ty(val->getContext()), j)); + for (auto idx : idxs) { + eidxs.push_back((unsigned)cast(idx)->getZExtValue()); + lidxs.push_back(idx); + } + for (auto v : addToDiffe(val, extractMeta(BuilderM, dif, eidxs), + BuilderM, addingType, lidxs, mask)) + res.push_back(v); + } + return res; + } } - if (auto ST = dyn_cast(val->getType())) { + if (auto ST = dyn_cast(VT)) { auto SL = DL.getStructLayout(ST); auto left_idx = SL->getElementContainingOffset(start); - assert(SL->getElementOffset(left_idx) == start); auto right_idx = ST->getNumElements(); if (storeSize != start + size) { right_idx = SL->getElementContainingOffset(start + size); - assert(SL->getElementOffset(right_idx) == start + size); + // If this doesn't cleanly end the window, make sure we do a partial + // accumulate for the remaining part in right_idx. + if (SL->getElementOffset(right_idx) != start + size) + right_idx++; } SmallVector res; for (auto i = left_idx; i < right_idx; i++) { - if (getWidth() == 1) { - Value *lidxs[] = { - ConstantInt::get(Type::getInt32Ty(val->getContext()), i)}; - for (auto v : addToDiffe(val, extractMeta(BuilderM, dif, i), BuilderM, - addingType, lidxs, mask)) - res.push_back(v); - } else { - for (int j = 0; j < getWidth(); j++) { - Value *lidxs[] = { - ConstantInt::get(Type::getInt32Ty(val->getContext()), j), - ConstantInt::get(Type::getInt32Ty(val->getContext()), i)}; - unsigned int idxs[] = {(unsigned int)j, (unsigned int)i}; - for (auto v : addToDiffe(val, extractMeta(BuilderM, dif, idxs), - BuilderM, addingType, lidxs, mask)) - res.push_back(v); - } - } + auto subType = ST->getElementType(i); + SmallVector lidxs(idxs.begin(), idxs.end()); + lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i)); + auto sub_start = + (i == left_idx) ? (start - (unsigned)SL->getElementOffset(i)) : 0; + auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8; + auto sub_end = (i == right_idx - 1) + ? min(start + size - (unsigned)SL->getElementOffset(i), + (unsigned)subTypeSize) + : subTypeSize; + for (auto v : addToDiffe(val, dif, BuilderM, addingType, sub_start, + sub_end - sub_start, lidxs, mask)) + res.push_back(v); } return res; } + + if (auto AT = dyn_cast(VT)) { + auto subType = AT->getElementType(); + auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8; + auto left_idx = start / subTypeSize; + auto right_idx = AT->getNumElements(); + if (storeSize != start + size) { + right_idx = (start + size) / subTypeSize; + // If this doesn't cleanly end the window, make sure we do a partial + // accumulate for the remaining part in right_idx. + if (right_idx * subTypeSize != start + size) + right_idx++; + } + SmallVector res; + for (auto i = left_idx; i < right_idx; i++) { + SmallVector lidxs(idxs.begin(), idxs.end()); + lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i)); + auto sub_start = (i == left_idx) ? (start - (i * subTypeSize)) : 0; + auto sub_end = (i == right_idx - 1) + ? min(start + size - (unsigned)(i * subTypeSize), + (unsigned)subTypeSize) + : subTypeSize; + for (auto v : addToDiffe(val, dif, BuilderM, addingType, sub_start, + sub_end - sub_start, lidxs, mask)) + res.push_back(v); + } + return res; + } + + llvm::errs() << " VT: " << *VT << " idxs:{"; + for (auto idx : idxs) + llvm::errs() << *idx << ","; + llvm::errs() << "} start=" << start << " size=" << size + << " storeSize=" << storeSize << " val=" << *val << "\n"; assert(0 && "unhandled accumulate with partial sizes"); } @@ -380,10 +456,14 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, } else { old = BuilderM.CreateLoad(getShadowType(val->getType()), ptr); } + if (dif->getType() != old->getType()) { + llvm::errs() << " val: " << *val << " dif: " << *dif << " old: " << *old + << "\n"; + } assert(dif->getType() == old->getType()); Value *res = nullptr; - if (old->getType()->isIntOrIntVectorTy()) { + if (old->getType()->isIntOrIntVectorTy() || old->getType()->isPointerTy()) { if (!addingType) { if (looseTypeAnalysis) { if (old->getType()->isIntegerTy(64)) @@ -397,6 +477,10 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, llvm::raw_string_ostream ss(s); ss << "oldFunc: " << *oldFunc << "\n"; ss << "Cannot deduce adding type of: " << *val << "\n"; + ss << " + idxs {"; + for (auto idx : idxs) + ss << *idx << ","; + ss << "}\n"; if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType, &TR.analyzer, nullptr, wrap(&BuilderM)); @@ -451,20 +535,38 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, } } - Value *bcold = BuilderM.CreateBitCast(old, addingType); - Value *bcdif = BuilderM.CreateBitCast(dif, addingType); + Value *bcold = old; + Value *bcdif = dif; + Type *intTy = nullptr; + if (old->getType()->isPointerTy()) { + auto &DL = oldFunc->getParent()->getDataLayout(); + intTy = Type::getIntNTy(old->getContext(), DL.getPointerSizeInBits()); + bcold = BuilderM.CreatePtrToInt(bcold, intTy); + bcdif = BuilderM.CreatePtrToInt(bcdif, intTy); + } else { + intTy = old->getType(); + } + + bcold = BuilderM.CreateBitCast(bcold, addingType); + bcdif = BuilderM.CreateBitCast(bcdif, addingType); res = faddForSelect(bcold, bcdif); if (SelectInst *select = dyn_cast(res)) { assert(addedSelects.back() == select); addedSelects.erase(addedSelects.end() - 1); - res = BuilderM.CreateSelect( - select->getCondition(), - BuilderM.CreateBitCast(select->getTrueValue(), old->getType()), - BuilderM.CreateBitCast(select->getFalseValue(), old->getType())); + + Value *tval = BuilderM.CreateBitCast(select->getTrueValue(), intTy); + Value *fval = BuilderM.CreateBitCast(select->getFalseValue(), intTy); + if (old->getType()->isPointerTy()) { + tval = BuilderM.CreateIntToPtr(tval, old->getType()); + fval = BuilderM.CreateIntToPtr(fval, old->getType()); + } + res = BuilderM.CreateSelect(select->getCondition(), tval, fval); assert(select->getNumUses() == 0); } else { - res = BuilderM.CreateBitCast(res, old->getType()); + res = BuilderM.CreateBitCast(res, intTy); + if (old->getType()->isPointerTy()) + res = BuilderM.CreateIntToPtr(res, old->getType()); } if (!mask) { BuilderM.CreateStore(res, ptr); @@ -541,6 +643,15 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, } return addedSelects; } else { + llvm::errs() << " idx: {"; + for (auto i : idxs) + llvm::errs() << *i << ", "; + llvm::errs() << "}\n"; + if (addingType) + llvm::errs() << " addingType: " << *addingType << "\n"; + else + llvm::errs() << " addingType: null\n"; + llvm::errs() << " oldType:" << *old->getType() << " old:" << *old << "\n"; llvm_unreachable("unknown type to add to diffe"); exit(1); } diff --git a/enzyme/Enzyme/DiffeGradientUtils.h b/enzyme/Enzyme/DiffeGradientUtils.h index eb354494186c..2660fc58ce79 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.h +++ b/enzyme/Enzyme/DiffeGradientUtils.h @@ -97,6 +97,7 @@ class DiffeGradientUtils final : public GradientUtils { llvm::SmallVector addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, unsigned start, unsigned size, + llvm::ArrayRef idxs = {}, llvm::Value *mask = nullptr); void setDiffe(llvm::Value *val, llvm::Value *toset, diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index ddbde71e6609..7bb5d7be1074 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -2604,22 +2604,56 @@ void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) { auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); g2->accumulateConstantOffset(dl, ai); + delete g2; // Using destructor rather than eraseFromParent // as g2 has no parent + + // Compute the offset at the next logical element [e.g. adding 1 to the last + // index, carrying the value on overflow] + for (ssize_t i = vec.size() - 1; i >= 0; i--) { + auto CI = cast(vec[i]); + auto val = CI->getZExtValue(); + if (i == 0) { + vec[i] = ConstantInt::get(CI->getType(), val + 1); + break; + } + auto subTy = GetElementPtrInst::getIndexedType( + I.getOperand(0)->getType(), ArrayRef(vec).slice(0, i)); + if (auto ST = dyn_cast(subTy)) { + if (val + 1 == ST->getNumElements()) { + vec.erase(vec.begin() + i, vec.end()); + continue; + } + vec[i] = ConstantInt::get(CI->getType(), val + 1); + break; + } else { + auto AT = cast(subTy); + if (val + 1 == AT->getNumElements()) { + vec.erase(vec.begin() + i, vec.end()); + continue; + } + vec[i] = ConstantInt::get(CI->getType(), val + 1); + break; + } + } + g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec); + APInt aiend(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); + g2->accumulateConstantOffset(dl, aiend); delete g2; int off = (int)ai.getLimitedValue(); - int agg_size = dl.getTypeSizeInBits(I.getType()) / 8; - int ins_size = - dl.getTypeSizeInBits(I.getInsertedValueOperand()->getType()) / 8; + int agg_size = (dl.getTypeSizeInBits(I.getType()) + 7) / 8; + int ins_size = (int)(aiend - ai).getLimitedValue(); + int ins2_size = + (dl.getTypeSizeInBits(I.getInsertedValueOperand()->getType()) + 7) / 8; if (direction & UP) updateAnalysis(I.getAggregateOperand(), getAnalysis(&I).Clear(off, off + ins_size, agg_size), &I); if (direction & UP) updateAnalysis(I.getInsertedValueOperand(), - getAnalysis(&I).ShiftIndices(dl, off, ins_size, 0), &I); + getAnalysis(&I).ShiftIndices(dl, off, ins2_size, 0), &I); auto new_res = getAnalysis(I.getAggregateOperand()).Clear(off, off + ins_size, agg_size); auto shifted = getAnalysis(I.getInsertedValueOperand()) diff --git a/enzyme/test/Enzyme/ReverseMode/insertuw.ll b/enzyme/test/Enzyme/ReverseMode/insertuw.ll index 50796276b7e8..711a1c3701c7 100644 --- a/enzyme/test/Enzyme/ReverseMode/insertuw.ll +++ b/enzyme/test/Enzyme/ReverseMode/insertuw.ll @@ -85,7 +85,7 @@ declare void @__enzyme_autodiff(...) ; CHECK-NEXT: %[[i9:.+]] = fadd fast double %[[i8]], %[[i2]] ; CHECK-NEXT: store double %[[i9]], double* %[[i7]] ; CHECK-NEXT: %[[i10:.+]] = load { double, double, double* }, { double, double, double* }* %"out1'de" -; CHECK-NEXT: %[[i12:.+]] = load { double, double, double* }, { double, double, double* }* %"x1'de" + ; CHECK-NEXT: %[[i13:.+]] = extractvalue { double, double, double* } %[[i10]], 0 ; CHECK-NEXT: %[[i14:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x1'de", i32 0, i32 0 ; CHECK-NEXT: %[[i15:.+]] = load double, double* %[[i14]] @@ -97,20 +97,17 @@ declare void @__enzyme_autodiff(...) ; CHECK-NEXT: %[[i20:.+]] = fadd fast double %[[i19]], %[[i17]] ; CHECK-NEXT: store double %[[i20]], double* %[[i18]] ; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"out1'de" + ; CHECK-NEXT: %[[i21:.+]] = load { double, double, double* }, { double, double, double* }* %"x1'de" ; CHECK-NEXT: %[[i22:.+]] = extractvalue { double, double, double* } %[[i21]], 1 ; CHECK-NEXT: %[[i23:.+]] = fadd fast double 0.000000e+00, %[[i22]] ; CHECK-NEXT: %[[i24:.+]] = load { double, double, double* }, { double, double, double* }* %"x1'de" -; CHECK-NEXT: %[[i26:.+]] = load { double, double, double* }, { double, double, double* }* %"x0'de" + ; CHECK-NEXT: %[[i27:.+]] = extractvalue { double, double, double* } %[[i24]], 0 ; CHECK-NEXT: %[[i28:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 0 ; CHECK-NEXT: %[[i29:.+]] = load double, double* %[[i28]] ; CHECK-NEXT: %[[i30:.+]] = fadd fast double %[[i29]], %[[i27]] ; CHECK-NEXT: store double %[[i30]], double* %[[i28]] -; CHECK-NEXT: %[[i31:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 1 -; CHECK-NEXT: %[[i32:.+]] = load double, double* %[[i31]] -; CHECK-NEXT: %[[i33:.+]] = fadd fast double %[[i32]], 0.000000e+00 -; CHECK-NEXT: store double %[[i33]], double* %[[i31]] ; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x1'de" ; CHECK-NEXT: store double 0.000000e+00, double* %"in1'" ; CHECK-NEXT: %[[i34:.+]] = load double, double* %"in1'" diff --git a/enzyme/test/Enzyme/ReverseMode/insertuw2.ll b/enzyme/test/Enzyme/ReverseMode/insertuw2.ll index 43f278591769..b43361bf01d7 100644 --- a/enzyme/test/Enzyme/ReverseMode/insertuw2.ll +++ b/enzyme/test/Enzyme/ReverseMode/insertuw2.ll @@ -89,7 +89,6 @@ declare void @__enzyme_autodiff(...) ; CHECK-NEXT: %[[i9:.+]] = fadd fast double %[[i8]], %[[i2]] ; CHECK-NEXT: store double %[[i9]], double* %[[i7]] ; CHECK-NEXT: %[[i10:.+]] = load { double, double, double* }, { double, double, double* }* %"out2'de" -; CHECK-NEXT: %[[i12:.+]] = load { double, double, double* }, { double, double, double* }* %"out1'de" ; CHECK-NEXT: %[[i13:.+]] = extractvalue { double, double, double* } %[[i10]], 0 ; CHECK-NEXT: %[[i14:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"out1'de", i32 0, i32 0 ; CHECK-NEXT: %[[i15:.+]] = load double, double* %[[i14]] @@ -101,7 +100,6 @@ declare void @__enzyme_autodiff(...) ; CHECK-NEXT: store double %[[i19]], double* %[[i17]] ; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"out2'de" ; CHECK-NEXT: %[[i20:.+]] = load { double, double, double* }, { double, double, double* }* %"out1'de" -; CHECK-NEXT: %[[i22:.+]] = load { double, double, double* }, { double, double, double* }* %"x1'de" ; CHECK-NEXT: %[[i23:.+]] = extractvalue { double, double, double* } %[[i20]], 0 ; CHECK-NEXT: %[[i24:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x1'de", i32 0, i32 0 ; CHECK-NEXT: %[[i25:.+]] = load double, double* %[[i24]] @@ -117,16 +115,11 @@ declare void @__enzyme_autodiff(...) ; CHECK-NEXT: %[[i32:.+]] = extractvalue { double, double, double* } %[[i31]], 1 ; CHECK-NEXT: %[[i33:.+]] = fadd fast double 0.000000e+00, %[[i32]] ; CHECK-NEXT: %[[i34:.+]] = load { double, double, double* }, { double, double, double* }* %"x1'de" -; CHECK-NEXT: %[[i36:.+]] = load { double, double, double* }, { double, double, double* }* %"x0'de" ; CHECK-NEXT: %[[i37:.+]] = extractvalue { double, double, double* } %[[i34]], 0 ; CHECK-NEXT: %[[i38:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 0 ; CHECK-NEXT: %[[i39:.+]] = load double, double* %[[i38]] ; CHECK-NEXT: %[[i40:.+]] = fadd fast double %[[i39]], %[[i37]] ; CHECK-NEXT: store double %[[i40]], double* %[[i38]] -; CHECK-NEXT: %[[i41:.+]] = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 1 -; CHECK-NEXT: %[[i42:.+]] = load double, double* %[[i41]] -; CHECK-NEXT: %[[i43:.+]] = fadd fast double %[[i42]], 0.000000e+00 -; CHECK-NEXT: store double %[[i43]], double* %[[i41]] ; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x1'de" ; CHECK-NEXT: store double 0.000000e+00, double* %"in1'" ; CHECK-NEXT: %[[i44:.+]] = load double, double* %"in1'" diff --git a/enzyme/test/Enzyme/ReverseMode/insertvalue.ll b/enzyme/test/Enzyme/ReverseMode/insertvalue.ll index 699caedcd682..5d7464db1407 100644 --- a/enzyme/test/Enzyme/ReverseMode/insertvalue.ll +++ b/enzyme/test/Enzyme/ReverseMode/insertvalue.ll @@ -39,14 +39,6 @@ declare double @__enzyme_autodiff(double (double)*, ...) ; CHECK-NEXT: %[[i9:.+]] = load double, double* %[[i8]] ; CHECK-NEXT: %[[i10:.+]] = fadd fast double %[[i9]], %[[i7]] ; CHECK-NEXT: store double %[[i10]], double* %[[i8]] -; CHECK-NEXT: %[[i11:.+]] = getelementptr inbounds [3 x double], [3 x double]* %"agg1'de", i32 0, i32 1 -; CHECK-NEXT: %[[i12:.+]] = load double, double* %[[i11]] -; CHECK-NEXT: store double %[[i12]], double* %[[i11]] -; CHECK-NEXT: %[[i13:.+]] = extractvalue [3 x double] %[[i5]], 2 -; CHECK-NEXT: %[[i14:.+]] = getelementptr inbounds [3 x double], [3 x double]* %"agg1'de", i32 0, i32 2 -; CHECK-NEXT: %[[i15:.+]] = load double, double* %[[i14]] -; CHECK-NEXT: %[[i16:.+]] = fadd fast double %[[i15]], %[[i13]] -; CHECK-NEXT: store double %[[i16]], double* %[[i14]] ; CHECK-NEXT: store [3 x double] zeroinitializer, [3 x double]* %"agg2'de" ; CHECK-NEXT: %[[m0diffex:.+]] = fmul fast double %4, %x ; CHECK-NEXT: %[[m1diffex:.+]] = fmul fast double %4, %x diff --git a/enzyme/test/Enzyme/ReverseMode/multistore2.ll b/enzyme/test/Enzyme/ReverseMode/multistore2.ll index f64100f46254..72da4cde72b9 100644 --- a/enzyme/test/Enzyme/ReverseMode/multistore2.ll +++ b/enzyme/test/Enzyme/ReverseMode/multistore2.ll @@ -21,6 +21,7 @@ entry: ; CHECK: define internal { double } @diffef({ double, i1 }* %y, { double, i1 }* %"y'", double %x, i1 %z) ; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = alloca { double, i1 }, align 8 ; CHECK-NEXT: %"ins2'de" = alloca { double, i1 }, align 8 ; CHECK-NEXT: store { double, i1 } zeroinitializer, { double, i1 }* %"ins2'de", align 8 ; CHECK-NEXT: %"ins'de" = alloca { double, i1 }, align 8 @@ -28,26 +29,30 @@ entry: ; CHECK-NEXT: %ins = insertvalue { double, i1 } undef, double %x, 0 ; CHECK-NEXT: %ins2 = insertvalue { double, i1 } %ins, i1 %z, 1 ; CHECK-NEXT: store { double, i1 } %ins2, { double, i1 }* %y, align 8 -; CHECK-NEXT: %0 = load { double, i1 }, { double, i1 }* %"y'", align 8 -; CHECK-NEXT: store { double, i1 } zeroinitializer, { double, i1 }* %"y'", align 8 -; CHECK-NEXT: %1 = load { double, i1 }, { double, i1 }* %"ins2'de", align 8 -; CHECK-NEXT: %2 = extractvalue { double, i1 } %0, 0 -; CHECK-NEXT: %3 = getelementptr inbounds { double, i1 }, { double, i1 }* %"ins2'de", i32 0, i32 0 -; CHECK-NEXT: %4 = load double, double* %3, align 8 -; CHECK-NEXT: %5 = fadd fast double %4, %2 -; CHECK-NEXT: store double %5, double* %3, align 8 -; CHECK-NEXT: %6 = load { double, i1 }, { double, i1 }* %"ins2'de", align 8 -; CHECK-NEXT: %7 = load { double, i1 }, { double, i1 }* %"ins'de", align 8 -; CHECK-NEXT: %8 = extractvalue { double, i1 } %6, 0 -; CHECK-NEXT: %9 = getelementptr inbounds { double, i1 }, { double, i1 }* %"ins'de", i32 0, i32 0 -; CHECK-NEXT: %10 = load double, double* %9, align 8 -; CHECK-NEXT: %11 = fadd fast double %10, %8 -; CHECK-NEXT: store double %11, double* %9, align 8 +; CHECK-NEXT: %[[i0:.+]] = load { double, i1 }, { double, i1 }* %"y'", align 8 +; CHECK-NEXT: store { double, i1 } zeroinitializer, { double, i1 }* %0, align 8 + +; CHECK-NEXT: %2 = bitcast { double, i1 }* %"y'" to i64* +; CHECK-NEXT: %3 = bitcast { double, i1 }* %0 to i64* +; CHECK-NEXT: %4 = load i64, i64* %3, align 4 +; CHECK-NEXT: store i64 %4, i64* %2, align 8 + +; CHECK-NEXT: %[[i2:.+]] = extractvalue { double, i1 } %[[i0]], 0 +; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds { double, i1 }, { double, i1 }* %"ins2'de", i32 0, i32 0 +; CHECK-NEXT: %[[i4:.+]] = load double, double* %[[i3]], align 8 +; CHECK-NEXT: %[[i5:.+]] = fadd fast double %[[i4]], %[[i2]] +; CHECK-NEXT: store double %[[i5]], double* %[[i3]], align 8 +; CHECK-NEXT: %[[i6:.+]] = load { double, i1 }, { double, i1 }* %"ins2'de", align 8 +; CHECK-NEXT: %[[i8:.+]] = extractvalue { double, i1 } %[[i6]], 0 +; CHECK-NEXT: %[[i9:.+]] = getelementptr inbounds { double, i1 }, { double, i1 }* %"ins'de", i32 0, i32 0 +; CHECK-NEXT: %[[i10:.+]] = load double, double* %[[i9]], align 8 +; CHECK-NEXT: %[[i11:.+]] = fadd fast double %[[i10]], %[[i8]] +; CHECK-NEXT: store double %[[i11]], double* %[[i9]], align 8 ; CHECK-NEXT: store { double, i1 } zeroinitializer, { double, i1 }* %"ins2'de", align 8 -; CHECK-NEXT: %12 = load { double, i1 }, { double, i1 }* %"ins'de", align 8 -; CHECK-NEXT: %13 = extractvalue { double, i1 } %12, 0 -; CHECK-NEXT: %14 = fadd fast double 0.000000e+00, %13 +; CHECK-NEXT: %[[i12:.+]] = load { double, i1 }, { double, i1 }* %"ins'de", align 8 +; CHECK-NEXT: %[[i13:.+]] = extractvalue { double, i1 } %[[i12]], 0 +; CHECK-NEXT: %[[i14:.+]] = fadd fast double 0.000000e+00, %[[i13]] ; CHECK-NEXT: store { double, i1 } zeroinitializer, { double, i1 }* %"ins'de", align 8 -; CHECK-NEXT: %15 = insertvalue { double } undef, double %14, 0 -; CHECK-NEXT: ret { double } %15 +; CHECK-NEXT: %[[i15:.+]] = insertvalue { double } undef, double %[[i14]], 0 +; CHECK-NEXT: ret { double } %[[i15]] ; CHECK-NEXT: } diff --git a/enzyme/test/TypeAnalysis/fwdiv.ll b/enzyme/test/TypeAnalysis/fwdiv.ll index 4c2f8b36fc8b..b4c5e1cd62ef 100644 --- a/enzyme/test/TypeAnalysis/fwdiv.ll +++ b/enzyme/test/TypeAnalysis/fwdiv.ll @@ -29,7 +29,7 @@ entry: ; CHECK-NEXT: entry ; CHECK-NEXT: %ld = load i64, i64* %p, align 8, !tbaa !2: {[-1]:Float@double} ; CHECK-NEXT: %iv1 = insertvalue { i32, i64 } undef, i64 %ld, 1: {[0]:Anything, [1]:Anything, [2]:Anything, [3]:Anything, [4]:Anything, [5]:Anything, [6]:Anything, [7]:Anything, [8]:Float@double} -; CHECK-NEXT: %iv2 = insertvalue { i32, i64 } %iv1, i32 4, 0: {[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Anything, [5]:Anything, [6]:Anything, [7]:Anything, [8]:Float@double} +; CHECK-NEXT: %iv2 = insertvalue { i32, i64 } %iv1, i32 4, 0: {[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double} ; CHECK-NEXT: %ev1 = extractvalue { i32, i64 } %iv2, 0: {[-1]:Integer} ; CHECK-NEXT: %ev2 = extractvalue { i32, i64 } %iv2, 1: {[-1]:Float@double} ; CHECK-NEXT: ret void: {} diff --git a/enzyme/test/TypeAnalysis/jlins.ll b/enzyme/test/TypeAnalysis/jlins.ll index d0417e6020e1..a04078ce7979 100644 --- a/enzyme/test/TypeAnalysis/jlins.ll +++ b/enzyme/test/TypeAnalysis/jlins.ll @@ -14,6 +14,6 @@ bb: ; CHECK: a0 - {} | ; CHECK-NEXT: bb -; CHECK-NEXT: %i = insertvalue { { i1, i1, i1 }, i8 } {{(undef|poison)}}, i1 false, 0, 0: {[-1]:Anything} -; CHECK-NEXT: %i2 = insertvalue { { i1, i1, i1 }, i8 } %i, i8 3, 1: {[0]:Anything, [1]:Anything, [2]:Anything, [3]:Integer} +; CHECK-NEXT: %i = insertvalue { { i1, i1, i1 }, i8 } {{(undef|poison)}}, i1 false, 0, 0: {[0]:Integer, [1]:Anything, [2]:Anything, [3]:Anything} +; CHECK-NEXT: %i2 = insertvalue { { i1, i1, i1 }, i8 } %i, i8 3, 1: {[0]:Integer, [1]:Anything, [2]:Anything, [3]:Integer} ; CHECK-NEXT: ret void: {}